mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
refactor: list tools
This commit is contained in:
parent
435e71eb60
commit
7a3e756020
|
@ -118,7 +118,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider, tenant_id)
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
@ -290,7 +292,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
|
|
|
@ -166,7 +166,7 @@ class BaseAgentRunner(AppRunner):
|
|||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
|
|
@ -3,6 +3,8 @@ from typing import Generic, Optional, TypeVar
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool))
|
||||
|
||||
|
||||
|
@ -27,3 +29,10 @@ class InstallPluginMessage(BaseModel):
|
|||
|
||||
event: Event
|
||||
data: str
|
||||
|
||||
|
||||
class PluginToolProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: ToolProviderEntityWithPlugin
|
|
@ -93,7 +93,14 @@ class BasePluginManager:
|
|||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params)
|
||||
rep = PluginDaemonBasicResponse[type](**response.json())
|
||||
json_response = response.json()
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
|
||||
rep = PluginDaemonBasicResponse[type](**json_response)
|
||||
if rep.code != 0:
|
||||
raise ValueError(f"got error from plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
@ -1,13 +1,65 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginManager):
|
||||
def fetch_tool_providers(self, tenant_id: str) -> list[ToolProviderEntity]:
|
||||
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given asset.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET", f"plugin/{tenant_id}/tools", list[ToolProviderEntity], params={"page": 1, "page_size": 256}
|
||||
"GET", f"plugin/{tenant_id}/tools", list[PluginToolProviderEntity], params={"page": 1, "page_size": 256}
|
||||
)
|
||||
return response
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
tool_provider: str,
|
||||
tool_name: str,
|
||||
credentials: dict[str, Any],
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/tool/invoke",
|
||||
ToolInvokeMessage,
|
||||
data={
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider,
|
||||
"tool": tool_name,
|
||||
"credentials": credentials,
|
||||
"tool_parameters": tool_parameters,
|
||||
},
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/tool/validate_credentials",
|
||||
bool,
|
||||
data={
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
|
|
@ -105,11 +105,11 @@ class Tool(ABC):
|
|||
"""
|
||||
return self.entity.parameters
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_merged_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get all runtime parameters
|
||||
get merged runtime parameters
|
||||
|
||||
:return: all runtime parameters
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
|
|
|
@ -12,11 +12,9 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
|||
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
tools: list[Tool]
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
self.tools = []
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
|
|
|
@ -16,7 +16,7 @@ class ToolRuntime(BaseModel):
|
|||
tool_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: Optional[dict[str, Any]] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.provider_type == ToolProviderType.API:
|
||||
super().__init__(**data)
|
||||
return
|
||||
self.tools = []
|
||||
|
||||
# load provider yaml
|
||||
provider = self.__class__.__module__.split(".")[-1]
|
||||
|
@ -76,9 +74,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
parent_type=BuiltinTool,
|
||||
)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(
|
||||
entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""),
|
||||
))
|
||||
tools.append(
|
||||
assistant_tool_class(
|
||||
entity=ToolEntity(**tool),
|
||||
runtime=ToolRuntime(tenant_id=""),
|
||||
)
|
||||
)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
@ -142,7 +143,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
"""
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
|
@ -153,10 +154,10 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(credentials)
|
||||
self._validate_credentials(user_id, credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
|
||||
if not cls._position:
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
def name_func(provider: ToolProviderApiEntity) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
|
|
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
|
|
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|||
|
||||
|
||||
class QRCodeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
|
|
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
|
|
@ -24,9 +24,10 @@ class ApiToolProviderController(ToolProviderController):
|
|||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.tools = []
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
|
||||
credentials_schema = {
|
||||
"auth_type": ProviderConfig(
|
||||
name="auth_type",
|
||||
|
|
|
@ -9,7 +9,7 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class UserTool(BaseModel):
|
||||
class ToolApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
|
@ -18,10 +18,10 @@ class UserTool(BaseModel):
|
|||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
|
@ -33,7 +33,7 @@ class UserToolProvider(BaseModel):
|
|||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = Field(default_factory=list)
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
@ -63,5 +63,5 @@ class UserToolProvider(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
class ToolProviderCredentialsApiEntity(BaseModel):
|
||||
credentials: dict[str, ProviderConfig]
|
||||
|
|
|
@ -224,6 +224,13 @@ class ToolParameter(BaseModel):
|
|||
max: Optional[Union[float, int]] = None
|
||||
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
|
@ -304,6 +311,9 @@ class ToolEntity(BaseModel):
|
|||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class PluginToolProvider(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
return super().get_tool(tool_name)
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
get credentials schema
|
||||
"""
|
||||
return super().get_credentials_schema()
|
||||
|
72
api/core/tools/plugin_tool/provider.py
Normal file
72
api/core/tools/plugin_tool/provider.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
from typing import Any
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
class PluginToolProviderController(BuiltinToolProviderController):
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_unique_identifier: str) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]:
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
41
api/core/tools/plugin_tool/tool.py
Normal file
41
api/core/tools/plugin_tool/tool.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_unique_identifier: str) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
manager = PluginToolManager()
|
||||
return manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
|
@ -6,7 +6,10 @@ from os import listdir, path
|
|||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
@ -24,7 +27,7 @@ from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
|||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
|
@ -41,38 +44,61 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
_hardcoded_providers = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
def get_builtin_provider(
|
||||
cls, provider: str, tenant_id: str
|
||||
) -> BuiltinToolProviderController | PluginToolProviderController:
|
||||
"""
|
||||
get the builtin provider
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider
|
||||
"""
|
||||
if len(cls._builtin_providers) == 0:
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_builtin_providers_cache()
|
||||
cls.load_hardcoded_providers_cache()
|
||||
|
||||
if provider not in cls._builtin_providers:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
|
||||
if provider not in cls._hardcoded_providers:
|
||||
# get plugin provider
|
||||
plugin_provider = cls.get_plugin_provider(provider, tenant_id)
|
||||
if plugin_provider:
|
||||
return plugin_provider
|
||||
|
||||
return cls._builtin_providers[provider]
|
||||
return cls._hardcoded_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
|
||||
def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
|
||||
"""
|
||||
get the plugin provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
providers = manager.fetch_tool_providers(tenant_id)
|
||||
provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None)
|
||||
if not provider_entity:
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
|
||||
return PluginToolProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
|
||||
return tool
|
||||
|
@ -97,12 +123,12 @@ class ToolManager:
|
|||
:return: the tool
|
||||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name, tenant_id)
|
||||
if not builtin_tool:
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id)
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
|
@ -131,7 +157,7 @@ class ToolManager:
|
|||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = cls.get_builtin_provider(provider_id)
|
||||
controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
|
@ -246,7 +272,7 @@ class ToolManager:
|
|||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
|
@ -294,7 +320,7 @@ class ToolManager:
|
|||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
|
@ -321,16 +347,17 @@ class ToolManager:
|
|||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
|
||||
def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
get the absolute path of the icon of the builtin provider
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
|
||||
:return: the absolute path of the icon, the mime type of the icon
|
||||
"""
|
||||
# get provider
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
|
||||
absolute_path = path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
|
@ -351,21 +378,48 @@ class ToolManager:
|
|||
return absolute_path, mime_type
|
||||
|
||||
@classmethod
|
||||
def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
def list_hardcoded_providers(cls):
|
||||
# use cache first
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._builtin_providers.values())
|
||||
yield from list(cls._hardcoded_providers.values())
|
||||
return
|
||||
|
||||
with cls._builtin_provider_lock:
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._builtin_providers.values())
|
||||
yield from list(cls._hardcoded_providers.values())
|
||||
return
|
||||
|
||||
yield from cls._list_builtin_providers()
|
||||
yield from cls._list_hardcoded_providers()
|
||||
|
||||
@classmethod
|
||||
def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
|
||||
"""
|
||||
list all the plugin providers
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_tool_providers(tenant_id)
|
||||
return [
|
||||
PluginToolProviderController(
|
||||
entity=provider.declaration,
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
)
|
||||
for provider in provider_entities
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_builtin_providers(
|
||||
cls, tenant_id: str
|
||||
) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
yield from cls.list_hardcoded_providers()
|
||||
# get plugin providers
|
||||
yield from cls.list_plugin_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
|
@ -391,7 +445,7 @@ class ToolManager:
|
|||
parent_type=BuiltinToolProviderController,
|
||||
)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.entity.identity.name] = provider
|
||||
cls._hardcoded_providers[provider.entity.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
|
||||
yield provider
|
||||
|
@ -403,13 +457,13 @@ class ToolManager:
|
|||
cls._builtin_providers_loaded = True
|
||||
|
||||
@classmethod
|
||||
def load_builtin_providers_cache(cls):
|
||||
for _ in cls.list_builtin_providers():
|
||||
def load_hardcoded_providers_cache(cls):
|
||||
for _ in cls.list_hardcoded_providers():
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def clear_builtin_providers_cache(cls):
|
||||
cls._builtin_providers = {}
|
||||
def clear_hardcoded_providers_cache(cls):
|
||||
cls._hardcoded_providers = {}
|
||||
cls._builtin_providers_loaded = False
|
||||
|
||||
@classmethod
|
||||
|
@ -423,7 +477,7 @@ class ToolManager:
|
|||
"""
|
||||
if len(cls._builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_builtin_providers_cache()
|
||||
cls.load_hardcoded_providers_cache()
|
||||
|
||||
if tool_name not in cls._builtin_tools_labels:
|
||||
return None
|
||||
|
@ -432,9 +486,9 @@ class ToolManager:
|
|||
|
||||
@classmethod
|
||||
def user_list_providers(
|
||||
cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral
|
||||
) -> list[UserToolProvider]:
|
||||
result_providers: dict[str, UserToolProvider] = {}
|
||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
result_providers: dict[str, ToolProviderApiEntity] = {}
|
||||
|
||||
filters = []
|
||||
if not typ:
|
||||
|
@ -444,7 +498,7 @@ class ToolManager:
|
|||
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
|
@ -666,4 +720,4 @@ class ToolManager:
|
|||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
ToolManager.load_hardcoded_providers_cache()
|
||||
|
|
|
@ -167,7 +167,7 @@ class WorkflowTool(Tool):
|
|||
:param tool_parameters: the tool parameters
|
||||
:return: tool_parameters, files
|
||||
"""
|
||||
parameter_rules = self.get_all_runtime_parameters()
|
||||
parameter_rules = self.get_merged_runtime_parameters()
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
|
|
|
@ -7,7 +7,7 @@ from core.entities.provider_entities import ProviderConfig
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
|
@ -201,7 +201,7 @@ class ApiToolManageService:
|
|||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
|
@ -438,7 +438,7 @@ class ApiToolManageService:
|
|||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
|
@ -447,7 +447,7 @@ class ApiToolManageService:
|
|||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
|
|
|
@ -5,9 +5,8 @@ from pathlib import Path
|
|||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
@ -21,11 +20,17 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
|
||||
:param user_id: the id of the user
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider: the name of the provider
|
||||
|
||||
:return: the list of tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
|
@ -64,14 +69,16 @@ class BuiltinToolManageService:
|
|||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()])
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
|
@ -90,7 +97,7 @@ class BuiltinToolManageService:
|
|||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
|
@ -109,7 +116,7 @@ class BuiltinToolManageService:
|
|||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# validate credentials
|
||||
provider_controller.validate_credentials(credentials)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
# encrypt credentials
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
|
||||
|
@ -154,7 +161,7 @@ class BuiltinToolManageService:
|
|||
if provider_obj is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
|
@ -186,7 +193,7 @@ class BuiltinToolManageService:
|
|||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
|
@ -198,22 +205,22 @@ class BuiltinToolManageService:
|
|||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
def get_builtin_tool_provider_icon(provider: str, tenant_id: str):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id)
|
||||
icon_bytes = Path(icon_path).read_bytes()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
|
@ -225,7 +232,7 @@ class BuiltinToolManageService:
|
|||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from core.tools.__base.tool import Tool
|
|||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
|
@ -15,6 +15,7 @@ from core.tools.entities.tool_entities import (
|
|||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
@ -44,7 +45,7 @@ class ToolTransformService:
|
|||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
def repack_provider(provider: Union[dict, ToolProviderApiEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
|
@ -54,7 +55,7 @@ class ToolTransformService:
|
|||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
)
|
||||
|
@ -62,14 +63,14 @@ class ToolTransformService:
|
|||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
|
@ -154,7 +155,7 @@ class ToolTransformService:
|
|||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
return UserToolProvider(
|
||||
return ToolProviderApiEntity(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
|
@ -181,7 +182,7 @@ class ToolTransformService:
|
|||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
|
@ -197,7 +198,7 @@ class ToolTransformService:
|
|||
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
|
@ -240,7 +241,7 @@ class ToolTransformService:
|
|||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserTool:
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
|
@ -248,7 +249,7 @@ class ToolTransformService:
|
|||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
credentials=credentials or {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
@ -270,7 +271,7 @@ class ToolTransformService:
|
|||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
|
@ -279,7 +280,7 @@ class ToolTransformService:
|
|||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
@ -183,7 +183,7 @@ class WorkflowToolManageService:
|
|||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
|
@ -309,7 +309,7 @@ class WorkflowToolManageService:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
|
|
Loading…
Reference in New Issue
Block a user