diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index 2443991d57..1c7cb39c92 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -35,7 +35,8 @@ class VannaTool(BuiltinTool): password = tool_parameters.get("password", "") port = tool_parameters.get("port", 0) - vn = VannaDefault(model=model, api_key=api_key) + base_url = self.runtime.credentials.get("base_url", None) + vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url}) db_type = tool_parameters.get("db_type", "") if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index 84724e921a..1d71414bf3 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -1,4 +1,6 @@ +import re from typing import Any +from urllib.parse import urlparse from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vanna.tools.vanna import VannaTool @@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class VannaProvider(BuiltinToolProviderController): + def _get_protocol_and_main_domain(self, url): + parsed_url = urlparse(url) + protocol = parsed_url.scheme + hostname = parsed_url.hostname + port = f":{parsed_url.port}" if parsed_url.port else "" + + # Check if the hostname is an IP address + is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None + + # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain + main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port + return f"{protocol}://{main_domain}" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + base_url = credentials.get("base_url") + if not base_url: + base_url = "https://ask.vanna.ai/rpc" + else: + base_url = base_url.removesuffix("/") + credentials["base_url"] = base_url try: VannaTool().fork_tool_runtime( runtime={ @@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController): tool_parameters={ "model": "chinook", "db_type": "SQLite", - "url": "https://vanna.ai/Chinook.sqlite", + "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite', "query": "What are the top 10 customers by sales?", }, ) diff --git a/api/core/tools/provider/builtin/vanna/vanna.yaml b/api/core/tools/provider/builtin/vanna/vanna.yaml index 7f953be172..cf3fdca562 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.yaml +++ b/api/core/tools/provider/builtin/vanna/vanna.yaml @@ -26,3 +26,10 @@ credentials_for_provider: en_US: Get your API key from Vanna.AI zh_Hans: 从 Vanna.AI 获取你的 API key url: https://vanna.ai/account/profile + base_url: + type: text-input + required: false + label: + en_US: Vanna.AI Endpoint Base URL + placeholder: + en_US: https://ask.vanna.ai/rpc