fixed the Base URL usage issue in Podcast Generator tool verification (#10697)

This commit is contained in:
Xiao Ley 2024-11-14 17:24:42 +08:00 committed by GitHub
parent 15f341b655
commit fbb9c1c249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
from typing import Any from typing import Any
import openai import openai
from yarl import URL
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -10,6 +11,7 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
tts_service = credentials.get("tts_service") tts_service = credentials.get("tts_service")
api_key = credentials.get("api_key") api_key = credentials.get("api_key")
base_url = credentials.get("openai_base_url")
if not tts_service: if not tts_service:
raise ToolProviderCredentialValidationError("TTS service is not specified") raise ToolProviderCredentialValidationError("TTS service is not specified")
@ -17,13 +19,16 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
if not api_key: if not api_key:
raise ToolProviderCredentialValidationError("API key is missing") raise ToolProviderCredentialValidationError("API key is missing")
if base_url:
base_url = str(URL(base_url) / "v1")
if tts_service == "openai": if tts_service == "openai":
self._validate_openai_credentials(api_key) self._validate_openai_credentials(api_key, base_url)
else: else:
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}") raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
def _validate_openai_credentials(self, api_key: str) -> None: def _validate_openai_credentials(self, api_key: str, base_url: str | None) -> None:
client = openai.OpenAI(api_key=api_key) client = openai.OpenAI(api_key=api_key, base_url=base_url)
try: try:
# We're using a simple API call to validate the credentials # We're using a simple API call to validate the credentials
client.models.list() client.models.list()