mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
fixed the Base URL usage issue in Podcast Generator tool verification (#10697)
This commit is contained in:
parent
15f341b655
commit
fbb9c1c249
|
@ -1,6 +1,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
@ -10,6 +11,7 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
tts_service = credentials.get("tts_service")
|
tts_service = credentials.get("tts_service")
|
||||||
api_key = credentials.get("api_key")
|
api_key = credentials.get("api_key")
|
||||||
|
base_url = credentials.get("openai_base_url")
|
||||||
|
|
||||||
if not tts_service:
|
if not tts_service:
|
||||||
raise ToolProviderCredentialValidationError("TTS service is not specified")
|
raise ToolProviderCredentialValidationError("TTS service is not specified")
|
||||||
|
@ -17,13 +19,16 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ToolProviderCredentialValidationError("API key is missing")
|
raise ToolProviderCredentialValidationError("API key is missing")
|
||||||
|
|
||||||
|
if base_url:
|
||||||
|
base_url = str(URL(base_url) / "v1")
|
||||||
|
|
||||||
if tts_service == "openai":
|
if tts_service == "openai":
|
||||||
self._validate_openai_credentials(api_key)
|
self._validate_openai_credentials(api_key, base_url)
|
||||||
else:
|
else:
|
||||||
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
|
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
|
||||||
|
|
||||||
def _validate_openai_credentials(self, api_key: str) -> None:
|
def _validate_openai_credentials(self, api_key: str, base_url: str | None) -> None:
|
||||||
client = openai.OpenAI(api_key=api_key)
|
client = openai.OpenAI(api_key=api_key, base_url=base_url)
|
||||||
try:
|
try:
|
||||||
# We're using a simple API call to validate the credentials
|
# We're using a simple API call to validate the credentials
|
||||||
client.models.list()
|
client.models.list()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user