diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b1db559441..ddb1481276 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,6 +1,6 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject @@ -32,9 +32,14 @@ class UserToolProvider(BaseModel): original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] | None = None + tools: list[UserTool] = Field(default_factory=list) labels: list[str] | None = None + @field_validator("tools", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + def to_dict(self) -> dict: # ------------- # overwrite tool parameter types for temp fix diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ed0cebf460..b6b0143fac 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -116,7 +116,7 @@ class ApiToolManageService: provider_name = provider_name.strip() # check if the provider exists - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -201,16 +201,15 @@ class ApiToolManageService: return {"schema": schema} @staticmethod - def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]: """ list api tool provider tools """ - provider_name = provider - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, + ApiToolProvider.name == provider_name, ) .first() ) @@ -252,7 +251,7 @@ class ApiToolManageService: provider_name = provider_name.strip() # check if the provider exists - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -319,7 +318,7 @@ class ApiToolManageService: """ delete tool provider """ - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -369,7 +368,7 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider: ApiToolProvider = ( + db_provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id,