From fbb9c1c249b8e3a956a708975c2780a8abf0baa6 Mon Sep 17 00:00:00 2001 From: Xiao Ley Date: Thu, 14 Nov 2024 17:24:42 +0800 Subject: [PATCH] fixed the Base URL usage issue in Podcast Generator tool verification (#10697) --- .../builtin/podcast_generator/podcast_generator.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py index 0b9c025834..a7f7ad2e78 100644 --- a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py @@ -1,6 +1,7 @@ from typing import Any import openai +from yarl import URL from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -10,6 +11,7 @@ class PodcastGeneratorProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: tts_service = credentials.get("tts_service") api_key = credentials.get("api_key") + base_url = credentials.get("openai_base_url") if not tts_service: raise ToolProviderCredentialValidationError("TTS service is not specified") @@ -17,13 +19,16 @@ class PodcastGeneratorProvider(BuiltinToolProviderController): if not api_key: raise ToolProviderCredentialValidationError("API key is missing") + if base_url: + base_url = str(URL(base_url) / "v1") + if tts_service == "openai": - self._validate_openai_credentials(api_key) + self._validate_openai_credentials(api_key, base_url) else: raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}") - def _validate_openai_credentials(self, api_key: str) -> None: - client = openai.OpenAI(api_key=api_key) + def _validate_openai_credentials(self, api_key: str, base_url: str | None) -> None: + client = openai.OpenAI(api_key=api_key, base_url=base_url) try: # We're using a simple API call to validate the credentials client.models.list()