diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 860ec5de0c..a703c31de9 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -173,7 +173,7 @@ class BaseAgentRunner(AppRunner): }: continue enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: + if parameter.type in {ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.MULTI_SELECT}: enum = [option.value for option in parameter.options] message_tool.parameters["properties"][parameter.name] = { @@ -264,7 +264,7 @@ class BaseAgentRunner(AppRunner): }: continue enum = [] - if parameter.type == ToolParameter.ToolParameterType.SELECT: + if parameter.type in {ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.MULTI_SELECT}: enum = [option.value for option in parameter.options] prompt_tool.parameters["properties"][parameter.name] = { diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d8637fd2cb..07e9dc6248 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -142,12 +142,13 @@ class ToolParameter(BaseModel): NUMBER = "number" BOOLEAN = "boolean" SELECT = "select" + MULTI_SELECT = "multi-select" SECRET_INPUT = "secret-input" FILE = "file" FILES = "files" # deprecated, should not use. - SYSTEM_FILES = "systme-files" + SYSTEM_FILES = "system-files" def as_normal_type(self): if self in { @@ -155,6 +156,8 @@ class ToolParameter(BaseModel): ToolParameter.ToolParameterType.SELECT, }: return "string" + elif self == ToolParameter.ToolParameterType.MULTI_SELECT: + return "array" return self.value def cast_value(self, value: Any, /): @@ -198,6 +201,7 @@ class ToolParameter(BaseModel): ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILE | ToolParameter.ToolParameterType.FILES + | ToolParameter.ToolParameterType.MULTI_SELECT ): return value case _: diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 15ab510c6c..18b9671a34 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -29,7 +29,7 @@ class CrawlTool(BuiltinTool): payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) payload["webhook"] = tool_parameters.get("webhook") - scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["formats"] = tool_parameters.get("formats") scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0d7dbcac20..0d51f0fed4 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -123,18 +123,35 @@ parameters: form: form ############## Scrape Options ####################### - name: formats - type: string + type: multi-select label: - en_US: Formats + en_US: Output Formats zh_Hans: 结果的格式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 + options: + - value: markdown + label: + en_US: markdown + - value: html + label: + en_US: html + - value: rawHtml + label: + en_US: rawHtml + - value: links + label: + en_US: links + - value: screenshot + label: + en_US: screenshot + - value: extract + label: + en_US: extract + - value: screenshot@fullPage + label: + en_US: screenshot@fullPage human_description: - en_US: | - Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot - zh_Hans: | - 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + en_US: Formats to include in the output. Multiple selections possible. + zh_Hans: 输出中应包含的格式。可多选。 form: form - name: headers type: string diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index f00a9b31ce..54e0e62299 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -18,7 +18,7 @@ class ScrapeTool(BuiltinTool): payload = {} extract = {} - payload["formats"] = get_array_params(tool_parameters, "formats") + payload["formats"] = tool_parameters.get("formats") payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) payload["includeTags"] = get_array_params(tool_parameters, "includeTags") payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 8f1f1348a4..472cf6eae8 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -23,18 +23,35 @@ parameters: form: llm ############## Payload ####################### - name: formats - type: string + type: multi-select label: - en_US: Formats + en_US: Output Formats zh_Hans: 结果的格式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 + options: + - value: markdown + label: + en_US: markdown + - value: html + label: + en_US: html + - value: rawHtml + label: + en_US: rawHtml + - value: links + label: + en_US: links + - value: screenshot + label: + en_US: screenshot + - value: extract + label: + en_US: extract + - value: screenshot@fullPage + label: + en_US: screenshot@fullPage human_description: - en_US: | - Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage - zh_Hans: | - 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + en_US: Formats to include in the output. Multiple selections possible. + zh_Hans: 输出中应包含的格式。可多选。 form: form - name: onlyMainContent type: boolean diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 955a0add3b..93cedd8716 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -7,7 +7,6 @@ from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredent from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolNotFoundError, - ToolParameterValidationError, ToolProviderNotFoundError, ) from core.tools.provider.tool_provider import ToolProviderController @@ -146,70 +145,6 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.identity.tags or [] - def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: - """ - validate the parameters of the tool and set the default value if needed - - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool - """ - tool_parameters_schema = self.get_parameters(tool_name) - - tool_parameters_need_to_validate: dict[str, ToolParameter] = {} - for parameter in tool_parameters_schema: - tool_parameters_need_to_validate[parameter.name] = parameter - - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") - - # check type - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") - - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") - - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" - ) - - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" - ) - - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") - - elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") - - options = parameter_schema.options - if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") - - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") - - tool_parameters_need_to_validate.pop(parameter) - - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] - if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") - - # the parameter is not set currently, set the default value if needed - if parameter_schema.default is not None: - default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value - def validate_credentials(self, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a11562..5253adcdd9 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -117,6 +117,18 @@ class ToolProviderController(BaseModel, ABC): if tool_parameters[parameter] not in [x.value for x in options]: raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + elif parameter_schema.type == ToolParameter.ToolParameterType.MULTI_SELECT: + if not isinstance(tool_parameters[parameter], list): + raise ToolParameterValidationError(f"parameter {parameter} should be list") + + options = parameter_schema.options + if not isinstance(options, list): + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + + for item in tool_parameters[parameter]: + if item not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ac333162b6..71fd174c4f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -220,6 +220,11 @@ class ToolManager: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" ) + elif parameter_rule.type == ToolParameter.ToolParameterType.MULTI_SELECT and parameter_value is not None: + options = [x.value for x in parameter_rule.options] + for value in parameter_value: + if value not in options: + raise ValueError(f"tool parameter {parameter_rule.name} value {value} not in options {options}") return parameter_rule.type.cast_value(parameter_value) @@ -228,7 +233,7 @@ class ToolManager: cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER ) -> Tool: """ - get the agent tool runtime + get the agent tool runtim """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py index 8a41678267..067bc52c26 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py @@ -5,6 +5,7 @@ def test_get_parameter_type(): assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string" assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string" assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.MULTI_SELECT.as_normal_type() == "array" assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean" assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number" assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file" @@ -30,6 +31,10 @@ def test_cast_parameter_by_type(): assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0" assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == "" + # multi select + assert ToolParameter.ToolParameterType.MULTI_SELECT.cast_value(["test", "test2"]) == ["test", "test2"] + assert ToolParameter.ToolParameterType.MULTI_SELECT.cast_value([2, 1, 1.0]) == [2, 1, 1.0] + # boolean true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index c70cf24661..cf36c60c89 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -29,7 +29,7 @@ export type Item = { export type ISelectProps = { className?: string wrapperClassName?: string - renderTrigger?: (value: Item | null) => JSX.Element | null + renderTrigger?: (value: Item | Item[] | null) => JSX.Element | null items?: Item[] defaultValue?: number | string disabled?: boolean @@ -378,5 +378,127 @@ const PortalSelect: FC = ({ ) } -export { SimpleSelect, PortalSelect } + +export type MultiSelectProps = Omit & { + onSelect: (values: Item[]) => void + selectedValues?: (number | string)[] +} + +const MultiSelect: FC = ({ + className, + wrapperClassName = '', + renderTrigger, + items = defaultItems, + selectedValues = [], + disabled = false, + onSelect, + placeholder, + optionWrapClassName, + optionClassName, + hideChecked, + renderOption, +}) => { + const { t } = useTranslation() + const localPlaceholder = placeholder || t('common.placeholder.select') + + const [selectedItems, setSelectedItems] = useState([]) + useEffect(() => { + const defaultSelected = items.filter((item: Item) => selectedValues.includes(item.value)) + setSelectedItems(defaultSelected) + }, [selectedValues, items]) + + const handleSelect = (newItems: Item[]) => { + if (!disabled) { + setSelectedItems(newItems) + onSelect(newItems) + } + } + + const removeItem = (item: Item) => { + const newSelectedItems = selectedItems.filter(i => i.value !== item.value) + setSelectedItems(newSelectedItems) + onSelect(newSelectedItems) + } + + return ( + +
+ {renderTrigger && {renderTrigger(selectedItems)}} + {!renderTrigger && ( + + {selectedItems.length > 0 + ? (selectedItems.map(item => ( + + {item.name} + { + e.stopPropagation() + removeItem(item) + }} + /> + + )) + ) + : ( + {localPlaceholder} + )} + + + + )} + + {!disabled && ( + + + {items.map((item: Item) => ( + + classNames( + `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : ''}`, + optionClassName, + ) + } + value={item} + > + {({ selected }) => ( + <> + {renderOption + ? renderOption({ item, selected }) + : ( + <> + {item.name} + {selected && !hideChecked && ( + + + )} + + )} + + )} + + ))} + + + )} +
+
+ ) +} +export { SimpleSelect, PortalSelect, MultiSelect } export default React.memo(Select) diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 8a84376bea..9342182dd4 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -11,6 +11,7 @@ export enum FormTypeEnum { textNumber = 'number-input', secretInput = 'secret-input', select = 'select', + multiSelect = 'multi-select', radio = 'radio', boolean = 'boolean', files = 'files', diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index c0a7be68a6..d2a8c4e11c 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -14,7 +14,7 @@ import { FormTypeEnum } from '../declarations' import { useLanguage } from '../hooks' import Input from './Input' import cn from '@/utils/classnames' -import { SimpleSelect } from '@/app/components/base/select' +import { MultiSelect, SimpleSelect } from '@/app/components/base/select' import Tooltip from '@/app/components/base/tooltip' import Radio from '@/app/components/base/radio' type FormProps = { @@ -53,7 +53,7 @@ const Form: FC = ({ const language = useLanguage() const [changeKey, setChangeKey] = useState('') - const handleFormChange = (key: string, val: string | boolean) => { + const handleFormChange = (key: string, val: string | boolean | any[]) => { if (isEditMode && (key === '__model_type' || key === '__model_name')) return @@ -220,6 +220,44 @@ const Form: FC = ({ ) } + if (formSchema.type === 'multi-select') { + const { + options, + variable, + label, + show_on, + required, + placeholder, + } = formSchema as CredentialFormSchemaSelect + + if (show_on.length && !show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) + return null + + const handleMultiSelect = (selectedItems: any[]) => { + handleFormChange(variable, selectedItems.map(item => item.value)) + } + + return ( +
+
+ {label[language] || label.en_US} + {required && *} + {tooltipContent} +
+ ({ value: item.value, name: item.label[language] || item.label.en_US }))} + onSelect={handleMultiSelect} + placeholder={placeholder?.[language] || placeholder?.en_US} + /> + {fieldMoreInfo?.(formSchema)} + {validating && changeKey === variable && } +
+ ) + } + if (formSchema.type === 'boolean') { const { variable, diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index e47082f4b7..fbf5d9a3f4 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -50,6 +50,8 @@ const InputVarList: FC = ({ return 'Files' else if (type === FormTypeEnum.select) return 'Options' + else if (type === FormTypeEnum.multiSelect) + return 'Array' else return 'String' } @@ -139,9 +141,10 @@ const InputVarList: FC = ({ const varInput = value[variable] const isNumber = type === FormTypeEnum.textNumber const isSelect = type === FormTypeEnum.select + const isMultiSelect = type === FormTypeEnum.multiSelect const isFile = type === FormTypeEnum.file const isFileArray = type === FormTypeEnum.files - const isString = !isNumber && !isSelect && !isFile && !isFileArray + const isString = !isNumber && !isSelect && !isMultiSelect && !isFile && !isFileArray return (
@@ -177,6 +180,20 @@ const InputVarList: FC = ({ schema={schema} /> )} + { + isMultiSelect && ( + varPayload.type === VarType.arrayString || varPayload.type === VarType.arrayNumber} + /> + ) + } {isFile && (