From ad00cc8ca1b5114f04e53de4d6bda234061cda1e Mon Sep 17 00:00:00 2001 From: hejl Date: Sat, 2 Nov 2024 14:00:27 +0800 Subject: [PATCH 1/6] support multi select --- api/core/agent/base_agent_runner.py | 4 +- api/core/tools/entities/tool_entities.py | 6 +- .../provider/builtin/firecrawl/tools/crawl.py | 2 +- .../builtin/firecrawl/tools/crawl.yaml | 35 +++-- .../builtin/firecrawl/tools/scrape.py | 2 +- .../builtin/firecrawl/tools/scrape.yaml | 35 +++-- .../tools/provider/builtin_tool_provider.py | 65 --------- api/core/tools/provider/tool_provider.py | 12 ++ api/core/tools/tool_manager.py | 7 + .../core/tools/test_tool_parameter_type.py | 5 + web/app/components/base/select/index.tsx | 126 +++++++++++++++++- .../model-provider-page/declarations.ts | 1 + .../model-provider-page/model-modal/Form.tsx | 42 +++++- .../nodes/tool/components/input-var-list.tsx | 19 ++- 14 files changed, 268 insertions(+), 93 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 507455c176..58127de697 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -172,7 +172,7 @@ class BaseAgentRunner(AppRunner): }: continue enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: + if parameter.type in {ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.MULTI_SELECT}: enum = [option.value for option in parameter.options] message_tool.parameters["properties"][parameter.name] = { @@ -263,7 +263,7 @@ class BaseAgentRunner(AppRunner): }: continue enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: + if parameter.type in {ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.MULTI_SELECT}: enum = [option.value for option in parameter.options] prompt_tool.parameters["properties"][parameter.name] = { diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d8637fd2cb..07e9dc6248 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -142,12 +142,13 @@ class ToolParameter(BaseModel): NUMBER = "number" BOOLEAN = "boolean" SELECT = "select" + MULTI_SELECT = "multi-select" SECRET_INPUT = "secret-input" FILE = "file" FILES = "files" # deprecated, should not use. - SYSTEM_FILES = "systme-files" + SYSTEM_FILES = "system-files" def as_normal_type(self): if self in { @@ -155,6 +156,8 @@ class ToolParameter(BaseModel): ToolParameter.ToolParameterType.SELECT, }: return "string" + elif self == ToolParameter.ToolParameterType.MULTI_SELECT: + return "array" return self.value def cast_value(self, value: Any, /): @@ -198,6 +201,7 @@ class ToolParameter(BaseModel): ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILE | ToolParameter.ToolParameterType.FILES + | ToolParameter.ToolParameterType.MULTI_SELECT ): return value case _: diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 15ab510c6c..18b9671a34 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -29,7 +29,7 @@ class CrawlTool(BuiltinTool): payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) payload["webhook"] = tool_parameters.get("webhook") - scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["formats"] = tool_parameters.get("formats") scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0d7dbcac20..0d51f0fed4 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -123,18 +123,35 @@ parameters: form: form ############## Scrape Options ####################### - name: formats - type: string + type: multi-select label: - en_US: Formats + en_US: Output Formats zh_Hans: 结果的格式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 + options: + - value: markdown + label: + en_US: markdown + - value: html + label: + en_US: html + - value: rawHtml + label: + en_US: rawHtml + - value: links + label: + en_US: links + - value: screenshot + label: + en_US: screenshot + - value: extract + label: + en_US: extract + - value: screenshot@fullPage + label: + en_US: screenshot@fullPage human_description: - en_US: | - Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot - zh_Hans: | - 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + en_US: Formats to include in the output. Multiple selections possible. + zh_Hans: 输出中应包含的格式。可多选。 form: form - name: headers type: string diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index f00a9b31ce..54e0e62299 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -18,7 +18,7 @@ class ScrapeTool(BuiltinTool): payload = {} extract = {} - payload["formats"] = get_array_params(tool_parameters, "formats") + payload["formats"] = tool_parameters.get("formats") payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) payload["includeTags"] = get_array_params(tool_parameters, "includeTags") payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 8f1f1348a4..472cf6eae8 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -23,18 +23,35 @@ parameters: form: llm ############## Payload ####################### - name: formats - type: string + type: multi-select label: - en_US: Formats + en_US: Output Formats zh_Hans: 结果的格式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 + options: + - value: markdown + label: + en_US: markdown + - value: html + label: + en_US: html + - value: rawHtml + label: + en_US: rawHtml + - value: links + label: + en_US: links + - value: screenshot + label: + en_US: screenshot + - value: extract + label: + en_US: extract + - value: screenshot@fullPage + label: + en_US: screenshot@fullPage human_description: - en_US: | - Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage - zh_Hans: | - 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + en_US: Formats to include in the output. Multiple selections possible. + zh_Hans: 输出中应包含的格式。可多选。 form: form - name: onlyMainContent type: boolean diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 955a0add3b..93cedd8716 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -7,7 +7,6 @@ from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredent from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolNotFoundError, - ToolParameterValidationError, ToolProviderNotFoundError, ) from core.tools.provider.tool_provider import ToolProviderController @@ -146,70 +145,6 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.identity.tags or [] - def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: - """ - validate the parameters of the tool and set the default value if needed - - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool - """ - tool_parameters_schema = self.get_parameters(tool_name) - - tool_parameters_need_to_validate: dict[str, ToolParameter] = {} - for parameter in tool_parameters_schema: - tool_parameters_need_to_validate[parameter.name] = parameter - - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") - - # check type - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") - - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") - - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" - ) - - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" - ) - - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") - - elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") - - options = parameter_schema.options - if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") - - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") - - tool_parameters_need_to_validate.pop(parameter) - - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") - - # the parameter is not set currently, set the default value if needed - if parameter_schema.default is not None: - default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value - def validate_credentials(self, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a11562..5253adcdd9 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -117,6 +117,18 @@ class ToolProviderController(BaseModel, ABC): if tool_parameters[parameter] not in [x.value for x in options]: raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + elif parameter_schema.type == ToolParameter.ToolParameterType.MULTI_SELECT: + if not isinstance(tool_parameters[parameter], list): + raise ToolParameterValidationError(f"parameter {parameter} should be list") + + options = parameter_schema.options + if not isinstance(options, list): + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + + for item in tool_parameters[parameter]: + if item not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 6abe0a9cba..0b34f3757d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -220,6 +220,13 @@ class ToolManager: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" ) + elif parameter_rule.type == ToolParameter.ToolParameterType.MULTI_SELECT and parameter_value is not None: + options = [x.value for x in parameter_rule.options] + for value in parameter_value: + if value not in options: + raise ValueError( + f"tool parameter {parameter_rule.name} value {value} not in options {options}" + ) return parameter_rule.type.cast_value(parameter_value) diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py index 8a41678267..067bc52c26 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py @@ -5,6 +5,7 @@ def test_get_parameter_type(): assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string" assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string" assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.MULTI_SELECT.as_normal_type() == "array" assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean" assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number" assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file" @@ -30,6 +31,10 @@ def test_cast_parameter_by_type(): assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0" assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == "" + # multi select + assert ToolParameter.ToolParameterType.MULTI_SELECT.cast_value(["test", "test2"]) == ["test", "test2"] + assert ToolParameter.ToolParameterType.MULTI_SELECT.cast_value([2, 1, 1.0]) == [2, 1, 1.0] + # boolean true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index ee5cee977b..6e4eb7feb1 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -29,7 +29,7 @@ export type Item = { export type ISelectProps = { className?: string wrapperClassName?: string - renderTrigger?: (value: Item | null) => JSX.Element | null + renderTrigger?: (value: Item | Item[] | null) => JSX.Element | null items?: Item[] defaultValue?: number | string disabled?: boolean @@ -378,5 +378,127 @@ const PortalSelect: FC = ({ ) } -export { SimpleSelect, PortalSelect } + +export type MultiSelectProps = Omit & { + onSelect: (values: Item[]) => void + selectedValues?: (number | string)[] +} + +const MultiSelect: FC = ({ + className, + wrapperClassName = '', + renderTrigger, + items = defaultItems, + selectedValues = [], + disabled = false, + onSelect, + placeholder, + optionWrapClassName, + optionClassName, + hideChecked, + renderOption, +}) => { + const { t } = useTranslation() + const localPlaceholder = placeholder || t('common.placeholder.select') + + const [selectedItems, setSelectedItems] = useState([]) + useEffect(() => { + const defaultSelected = items.filter((item: Item) => selectedValues.includes(item.value)) + setSelectedItems(defaultSelected) + }, [selectedValues, items]) + + const handleSelect = (newItems: Item[]) => { + if (!disabled) { + setSelectedItems(newItems) + onSelect(newItems) + } + } + + const removeItem = (item: Item) => { + const newSelectedItems = selectedItems.filter(i => i.value !== item.value) + setSelectedItems(newSelectedItems) + onSelect(newSelectedItems) + } + + return ( + +
+ {renderTrigger && {renderTrigger(selectedItems)}} + {!renderTrigger && ( + + {selectedItems.length > 0 + ? (selectedItems.map(item => ( + + {item.name} + { + e.stopPropagation() + removeItem(item) + }} + /> + + )) + ) + : ( + {localPlaceholder} + )} + + + + )} + + {!disabled && ( + + + {items.map((item: Item) => ( + + classNames( + `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : ''}`, + optionClassName, + ) + } + value={item} + > + {({ selected }) => ( + <> + {renderOption + ? renderOption({ item, selected }) + : ( + <> + {item.name} + {selected && !hideChecked && ( + + + )} + + )} + + )} + + ))} + + + )} +
+
+ ) +} +export { SimpleSelect, PortalSelect, MultiSelect } export default React.memo(Select) diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 8a84376bea..9342182dd4 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -11,6 +11,7 @@ export enum FormTypeEnum { textNumber = 'number-input', secretInput = 'secret-input', select = 'select', + multiSelect = 'multi-select', radio = 'radio', boolean = 'boolean', files = 'files', diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index c0a7be68a6..d2a8c4e11c 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -14,7 +14,7 @@ import { FormTypeEnum } from '../declarations' import { useLanguage } from '../hooks' import Input from './Input' import cn from '@/utils/classnames' -import { SimpleSelect } from '@/app/components/base/select' +import { MultiSelect, SimpleSelect } from '@/app/components/base/select' import Tooltip from '@/app/components/base/tooltip' import Radio from '@/app/components/base/radio' type FormProps = { @@ -53,7 +53,7 @@ const Form: FC = ({ const language = useLanguage() const [changeKey, setChangeKey] = useState('') - const handleFormChange = (key: string, val: string | boolean) => { + const handleFormChange = (key: string, val: string | boolean | any[]) => { if (isEditMode && (key === '__model_type' || key === '__model_name')) return @@ -220,6 +220,44 @@ const Form: FC = ({ ) } + if (formSchema.type === 'multi-select') { + const { + options, + variable, + label, + show_on, + required, + placeholder, + } = formSchema as CredentialFormSchemaSelect + + if (show_on.length && !show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) + return null + + const handleMultiSelect = (selectedItems: any[]) => { + handleFormChange(variable, selectedItems.map(item => item.value)) + } + + return ( +
+
+ {label[language] || label.en_US} + {required && *} + {tooltipContent} +
+ ({ value: item.value, name: item.label[language] || item.label.en_US }))} + onSelect={handleMultiSelect} + placeholder={placeholder?.[language] || placeholder?.en_US} + /> + {fieldMoreInfo?.(formSchema)} + {validating && changeKey === variable && } +
+ ) + } + if (formSchema.type === 'boolean') { const { variable, diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index e47082f4b7..fbf5d9a3f4 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -50,6 +50,8 @@ const InputVarList: FC = ({ return 'Files' else if (type === FormTypeEnum.select) return 'Options' + else if (type === FormTypeEnum.multiSelect) + return 'Array' else return 'String' } @@ -139,9 +141,10 @@ const InputVarList: FC = ({ const varInput = value[variable] const isNumber = type === FormTypeEnum.textNumber const isSelect = type === FormTypeEnum.select + const isMultiSelect = type === FormTypeEnum.multiSelect const isFile = type === FormTypeEnum.file const isFileArray = type === FormTypeEnum.files - const isString = !isNumber && !isSelect && !isFile && !isFileArray + const isString = !isNumber && !isSelect && !isMultiSelect && !isFile && !isFileArray return (
@@ -177,6 +180,20 @@ const InputVarList: FC = ({ schema={schema} /> )} + { + isMultiSelect && ( + varPayload.type === VarType.arrayString || varPayload.type === VarType.arrayNumber} + /> + ) + } {isFile && ( Date: Sat, 2 Nov 2024 21:33:29 +0800 Subject: [PATCH 2/6] fix CI --- api/core/tools/tool_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0b34f3757d..6831b7db90 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -224,9 +224,7 @@ class ToolManager: options = [x.value for x in parameter_rule.options] for value in parameter_value: if value not in options: - raise ValueError( - f"tool parameter {parameter_rule.name} value {value} not in options {options}" - ) + raise ValueError(f"tool parameter {parameter_rule.name} value {value} not in options {options}") return parameter_rule.type.cast_value(parameter_value) From a7b175ce09e997bebb75293b3610301f52a20896 Mon Sep 17 00:00:00 2001 From: hejl Date: Sat, 2 Nov 2024 21:42:57 +0800 Subject: [PATCH 3/6] fix re-run CI --- api/core/tools/tool_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 6831b7db90..329724bb1c 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -234,6 +234,7 @@ class ToolManager: ) -> Tool: """ get the agent tool runtime + """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, From 1c252e312796bd5dda5bb55be62cdeea8cd22b91 Mon Sep 17 00:00:00 2001 From: hejl Date: Sat, 2 Nov 2024 21:43:04 +0800 Subject: [PATCH 4/6] fix re-run CI --- api/core/tools/tool_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 329724bb1c..0267076f60 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -233,8 +233,7 @@ class ToolManager: cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER ) -> Tool: """ - get the agent tool runtime - + get the agent tool runtim """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, From 024579febca45c52655f4eb772a069fa80f5b71a Mon Sep 17 00:00:00 2001 From: hejl Date: Tue, 5 Nov 2024 10:22:02 +0800 Subject: [PATCH 5/6] add xAI --- .../model_providers/x/__init__.py | 0 .../model_providers/x/_assets/x-ai-logo.svg | 1 + .../model_providers/x/llm/__init__.py | 0 .../model_providers/x/llm/grok-beta.yaml | 63 +++++++++++++++++++ .../model_providers/x/llm/llm.py | 35 +++++++++++ api/core/model_runtime/model_providers/x/x.py | 25 ++++++++ .../model_runtime/model_providers/x/x.yaml | 38 +++++++++++ 7 files changed, 162 insertions(+) create mode 100644 api/core/model_runtime/model_providers/x/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg create mode 100644 api/core/model_runtime/model_providers/x/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/llm/grok-beta.yaml create mode 100644 api/core/model_runtime/model_providers/x/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/x/x.py create mode 100644 api/core/model_runtime/model_providers/x/x.yaml diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 0000000000..f8b745cb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 0000000000..542149577b --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 0000000000..5c0680a2e7 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,35 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials["api_base"]) / "v1") + credentials["mode"] = LLMMode.CHAT.value diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 0000000000..fc4ed822b5 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 0000000000..56a401c080 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. +icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: api_base + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base From a16a1fffb1172c4ef656eeab3ea8a36004b2643c Mon Sep 17 00:00:00 2001 From: hejl Date: Tue, 5 Nov 2024 11:51:40 +0800 Subject: [PATCH 6/6] Revert "add xAI" This reverts commit 024579febca45c52655f4eb772a069fa80f5b71a. --- .../model_providers/x/__init__.py | 0 .../model_providers/x/_assets/x-ai-logo.svg | 1 - .../model_providers/x/llm/__init__.py | 0 .../model_providers/x/llm/grok-beta.yaml | 63 ------------------- .../model_providers/x/llm/llm.py | 35 ----------- api/core/model_runtime/model_providers/x/x.py | 25 -------- .../model_runtime/model_providers/x/x.yaml | 38 ----------- 7 files changed, 162 deletions(-) delete mode 100644 api/core/model_runtime/model_providers/x/__init__.py delete mode 100644 api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg delete mode 100644 api/core/model_runtime/model_providers/x/llm/__init__.py delete mode 100644 api/core/model_runtime/model_providers/x/llm/grok-beta.yaml delete mode 100644 api/core/model_runtime/model_providers/x/llm/llm.py delete mode 100644 api/core/model_runtime/model_providers/x/x.py delete mode 100644 api/core/model_runtime/model_providers/x/x.yaml diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg deleted file mode 100644 index f8b745cb13..0000000000 --- a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml deleted file mode 100644 index 542149577b..0000000000 --- a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml +++ /dev/null @@ -1,63 +0,0 @@ -model: grok-beta -label: - en_US: Grok beta -model_type: llm -features: - - multi-tool-call -model_properties: - mode: chat - context_size: 131072 -parameter_rules: - - name: temperature - label: - en_US: "Temperature" - zh_Hans: "采样温度" - type: float - default: 0.7 - min: 0.0 - max: 2.0 - precision: 1 - required: true - help: - en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." - zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" - - - name: top_p - label: - en_US: "Top P" - zh_Hans: "Top P" - type: float - default: 0.7 - min: 0.0 - max: 1.0 - precision: 1 - required: true - help: - en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." - zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" - - - name: frequency_penalty - use_template: frequency_penalty - label: - en_US: "Frequency Penalty" - zh_Hans: "频率惩罚" - type: float - default: 0 - min: 0 - max: 2.0 - precision: 1 - required: false - help: - en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." - zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" - - - name: user - use_template: text - label: - en_US: "User" - zh_Hans: "用户" - type: string - required: false - help: - en_US: "Used to track and differentiate conversation requests from different users." - zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py deleted file mode 100644 index 5c0680a2e7..0000000000 --- a/api/core/model_runtime/model_providers/x/llm/llm.py +++ /dev/null @@ -1,35 +0,0 @@ -from collections.abc import Generator -from typing import Optional, Union - -from yarl import URL - -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult -from core.model_runtime.entities.message_entities import ( - PromptMessage, - PromptMessageTool, -) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel - - -class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) - - def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) - - @staticmethod - def _add_custom_parameters(credentials) -> None: - credentials["endpoint_url"] = str(URL(credentials["api_base"]) / "v1") - credentials["mode"] = LLMMode.CHAT.value diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py deleted file mode 100644 index fc4ed822b5..0000000000 --- a/api/core/model_runtime/model_providers/x/x.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging - -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class XAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - """ - Validate provider credentials - if validate failed, raise exception - - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - """ - try: - model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="grok-beta", credentials=credentials) - except CredentialsValidateFailedError as ex: - raise ex - except Exception as ex: - logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") - raise ex \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml deleted file mode 100644 index 56a401c080..0000000000 --- a/api/core/model_runtime/model_providers/x/x.yaml +++ /dev/null @@ -1,38 +0,0 @@ -provider: x -label: - en_US: xAI -description: - en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. -icon_small: - en_US: x-ai-logo.svg -icon_large: - en_US: x-ai-logo.svg -help: - title: - en_US: Get your token from xAI - zh_Hans: 从 xAI 获取 token - url: - en_US: https://x.ai/api -supported_model_types: - - llm -configurate_methods: - - predefined-model -provider_credential_schema: - credential_form_schemas: - - variable: api_key - label: - en_US: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入您的 API Key - en_US: Enter your API Key - - variable: api_base - label: - en_US: API Base - type: text-input - required: false - default: https://api.x.ai - placeholder: - zh_Hans: 在此输入您的 API Base - en_US: Enter your API Base