mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
refactor
This commit is contained in:
parent
91cb80f795
commit
435e71eb60
|
@ -3,7 +3,7 @@ from typing import Generic, Optional, TypeVar
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | bool))
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool))
|
||||
|
||||
|
||||
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
|
||||
|
|
|
@ -12,7 +12,7 @@ from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
|
|||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | bool))
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool))
|
||||
|
||||
|
||||
class BasePluginManager:
|
||||
|
@ -22,6 +22,7 @@ class BasePluginManager:
|
|||
path: str,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
stream: bool = False,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
|
@ -30,16 +31,23 @@ class BasePluginManager:
|
|||
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = plugin_daemon_inner_api_key
|
||||
response = requests.request(method=method, url=str(url), headers=headers, data=data, stream=stream)
|
||||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream
|
||||
)
|
||||
return response
|
||||
|
||||
def _stream_request(
|
||||
self, method: str, path: str, headers: dict | None = None, data: bytes | dict | None = None
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API
|
||||
"""
|
||||
response = self._request(method, path, headers, data, stream=True)
|
||||
response = self._request(method, path, headers, data, params, stream=True)
|
||||
yield from response.iter_lines()
|
||||
|
||||
def _stream_request_with_model(
|
||||
|
@ -49,29 +57,42 @@ class BasePluginManager:
|
|||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, headers, data):
|
||||
for line in self._stream_request(method, path, params, headers, data):
|
||||
yield type(**json.loads(line))
|
||||
|
||||
def _request_with_model(
|
||||
self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | None = None
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data)
|
||||
response = self._request(method, path, headers, data, params)
|
||||
return type(**response.json())
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data)
|
||||
response = self._request(method, path, headers, data, params)
|
||||
rep = PluginDaemonBasicResponse[type](**response.json())
|
||||
if rep.code != 0:
|
||||
raise ValueError(f"got error from plugin daemon: {rep.message}, code: {rep.code}")
|
||||
|
@ -81,12 +102,18 @@ class BasePluginManager:
|
|||
return rep.data
|
||||
|
||||
def _request_with_plugin_daemon_response_stream(
|
||||
self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, headers, data):
|
||||
for line in self._stream_request(method, path, params, headers, data):
|
||||
line_data = json.loads(line)
|
||||
rep = PluginDaemonBasicResponse[type](**line_data)
|
||||
if rep.code != 0:
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginModelManager(BasePluginManager):
|
||||
pass
|
||||
def fetch_model_providers(self, tenant_id: str) -> list[ProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET", f"plugin/{tenant_id}/models", list[ProviderEntity], params={"page": 1, "page_size": 256}
|
||||
)
|
||||
return response
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import Generator
|
||||
from urllib.parse import quote
|
||||
|
||||
from core.plugin.entities.plugin_daemon import InstallPluginMessage
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
@ -9,9 +8,8 @@ class PluginInstallationManager(BasePluginManager):
|
|||
def fetch_plugin_by_identifier(self, tenant_id: str, identifier: str) -> bool:
|
||||
# urlencode the identifier
|
||||
|
||||
identifier = quote(identifier)
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET", f"/plugin/{tenant_id}/fetch/identifier?plugin_unique_identifier={identifier}", bool
|
||||
"GET", f"plugin/{tenant_id}/fetch/identifier", bool, params={"plugin_unique_identifier": identifier}
|
||||
)
|
||||
|
||||
def install_from_pkg(self, tenant_id: str, pkg: bytes) -> Generator[InstallPluginMessage, None, None]:
|
||||
|
@ -22,21 +20,20 @@ class PluginInstallationManager(BasePluginManager):
|
|||
body = {"dify_pkg": ("dify_pkg", pkg, "application/octet-stream")}
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST", f"/plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body
|
||||
"POST", f"plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body
|
||||
)
|
||||
|
||||
def install_from_identifier(self, tenant_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Install a plugin from an identifier.
|
||||
"""
|
||||
identifier = quote(identifier)
|
||||
# exception will be raised if the request failed
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"/plugin/{tenant_id}/install/identifier",
|
||||
f"plugin/{tenant_id}/install/identifier",
|
||||
bool,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
params={
|
||||
"plugin_unique_identifier": identifier,
|
||||
},
|
||||
data={
|
||||
"plugin_unique_identifier": identifier,
|
||||
|
@ -48,5 +45,5 @@ class PluginInstallationManager(BasePluginManager):
|
|||
Uninstall a plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"DELETE", f"/plugin/{tenant_id}/uninstall?plugin_unique_identifier={identifier}", bool
|
||||
"DELETE", f"plugin/{tenant_id}/uninstall", bool, params={"plugin_unique_identifier": identifier}
|
||||
)
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginManager):
|
||||
def fetch_tool_providers(self, asset_id: str) -> list[str]:
|
||||
def fetch_tool_providers(self, tenant_id: str) -> list[ToolProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given asset.
|
||||
"""
|
||||
response = self._request('GET', f'/plugin/asset/{asset_id}')
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET", f"plugin/{tenant_id}/tools", list[ToolProviderEntity], params={"page": 1, "page_size": 256}
|
||||
)
|
||||
return response
|
||||
|
|
|
@ -274,9 +274,12 @@ class ToolProviderIdentity(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
|
@ -284,12 +287,24 @@ class ToolDescription(BaseModel):
|
|||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
|
@ -352,15 +367,4 @@ class ToolInvokeFrom(Enum):
|
|||
AGENT = "agent"
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
|
|
@ -83,3 +83,8 @@ VOLC_EMBEDDING_ENDPOINT_ID=
|
|||
|
||||
# 360 AI Credentials
|
||||
ZHINAO_API_KEY=
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_API_KEY=
|
||||
PLUGIN_API_URL=
|
||||
INNER_API_KEY=
|
66
api/tests/integration_tests/plugin/__mock/http.py
Normal file
66
api/tests/integration_tests/plugin/__mock/http.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import os
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity, ToolProviderIdentity
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
@classmethod
|
||||
def list_tools(cls) -> list[ToolProviderEntity]:
|
||||
return [
|
||||
ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author="Yeuoly",
|
||||
name="Yeuoly",
|
||||
description=I18nObject(en_US="Yeuoly"),
|
||||
icon="ssss.svg",
|
||||
label=I18nObject(en_US="Yeuoly"),
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def requests_request(
|
||||
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Mocked requests.request
|
||||
"""
|
||||
request = requests.PreparedRequest()
|
||||
request.method = method
|
||||
request.url = url
|
||||
if url.endswith("/tools"):
|
||||
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
|
||||
code=0, message="success", data=cls.list_tools()
|
||||
).model_dump_json()
|
||||
else:
|
||||
raise ValueError("")
|
||||
|
||||
response = requests.Response()
|
||||
response.status_code = 200
|
||||
response.request = request
|
||||
response._content = content.encode("utf-8")
|
||||
return response
|
||||
|
||||
|
||||
MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK_SWITCH:
|
||||
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
|
||||
|
||||
def unpatch():
|
||||
monkeypatch.undo()
|
||||
|
||||
yield
|
||||
|
||||
if MOCK_SWITCH:
|
||||
unpatch()
|
|
@ -0,0 +1,9 @@
|
|||
from core.plugin.manager.tool import PluginToolManager
|
||||
from tests.integration_tests.plugin.__mock.http import setup_http_mock
|
||||
|
||||
|
||||
def test_fetch_all_plugin_tools(setup_http_mock):
|
||||
manager = PluginToolManager()
|
||||
tools = manager.fetch_tool_providers(tenant_id="test-tenant")
|
||||
assert len(tools) >= 1
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
provider_generator = ToolManager.list_builtin_providers()
|
||||
provider_names = [provider.identity.name for provider in provider_generator]
|
||||
ToolManager.clear_builtin_providers_cache()
|
||||
provider_generator = ToolManager.list_builtin_providers()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", provider_names)
|
||||
def test_tool_providers(benchmark, name):
|
||||
"""
|
||||
Test that all tool providers can be loaded
|
||||
"""
|
||||
|
||||
def test(generator):
|
||||
try:
|
||||
return next(generator)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)
|
Loading…
Reference in New Issue
Block a user