feat: support spark v2 validate (#1086)

This commit is contained in:
takatost 2023-09-01 20:53:32 +08:00 committed by GitHub
parent 73c86ee6a0
commit a7cdb745c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -83,14 +83,32 @@ class SparkProvider(BaseModelProvider):
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
try:
chat_llm = ChatSpark(
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@ -108,6 +126,9 @@ class SparkProvider(BaseModelProvider):
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: