refactor(api): improve handling of tools field and cleanup variable usage (#10553)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions

This commit is contained in:
-LAN- 2024-11-12 00:08:04 +08:00 committed by GitHub
parent b7238caea5
commit 16b9665033
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 10 deletions

View File

@ -1,6 +1,6 @@
from typing import Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -32,9 +32,14 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None original_credentials: Optional[dict] = None
is_team_authorization: bool = False is_team_authorization: bool = False
allow_delete: bool = True allow_delete: bool = True
tools: list[UserTool] | None = None tools: list[UserTool] = Field(default_factory=list)
labels: list[str] | None = None labels: list[str] | None = None
@field_validator("tools", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict: def to_dict(self) -> dict:
# ------------- # -------------
# overwrite tool parameter types for temp fix # overwrite tool parameter types for temp fix

View File

@ -116,7 +116,7 @@ class ApiToolManageService:
provider_name = provider_name.strip() provider_name = provider_name.strip()
# check if the provider exists # check if the provider exists
provider: ApiToolProvider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -201,16 +201,15 @@ class ApiToolManageService:
return {"schema": schema} return {"schema": schema}
@staticmethod @staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
""" """
list api tool provider tools list api tool provider tools
""" """
provider_name = provider provider = (
provider: ApiToolProvider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider_name,
) )
.first() .first()
) )
@ -252,7 +251,7 @@ class ApiToolManageService:
provider_name = provider_name.strip() provider_name = provider_name.strip()
# check if the provider exists # check if the provider exists
provider: ApiToolProvider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -319,7 +318,7 @@ class ApiToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider: ApiToolProvider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -369,7 +368,7 @@ class ApiToolManageService:
if tool_bundle is None: if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}") raise ValueError(f"invalid tool name {tool_name}")
db_provider: ApiToolProvider = ( db_provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,