Mirror of https://github.com/langgenius/dify.git, synced 2024-11-16 11:42:29 +08:00
Merge branch 'feat/tooltip' into deploy/dev
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in: commit 2ba2594718

@@ -92,7 +92,7 @@ class AppMetaApi(Resource):
 class AppInfoApi(Resource):
     @validate_app_token
     def get(self, app_model: App):
-        """Get app infomation"""
+        """Get app information"""
         return {
             'name':app_model.name,
             'description':app_model.description

@@ -1,6 +1,6 @@
 import base64
 import io
 import json
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast

@@ -18,6 +18,7 @@ from anthropic.types import (
 )
 from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
+from PIL import Image

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta

@@ -462,7 +463,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 # fetch image data from url
                 try:
                     image_content = requests.get(message_content.data).content
-                    mime_type, _ = mimetypes.guess_type(message_content.data)
+                    with Image.open(io.BytesIO(image_content)) as img:
+                        mime_type = f"image/{img.format.lower()}"
                     base64_data = base64.b64encode(image_content).decode('utf-8')
                 except Exception as ex:
                     raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

@@ -1,8 +1,8 @@
 # standard import
 import base64
 import io
 import json
 import logging
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast

@@ -17,6 +17,7 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
+from PIL.Image import Image

 # local import
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta

@@ -381,9 +382,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 try:
                     url = message_content.data
                     image_content = requests.get(url).content
-                    if '?' in url:
-                        url = url.split('?')[0]
-                    mime_type, _ = mimetypes.guess_type(url)
+                    with Image.open(io.BytesIO(image_content)) as img:
+                        mime_type = f"image/{img.format.lower()}"
                     base64_data = base64.b64encode(image_content).decode('utf-8')
                 except Exception as ex:
                     raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

@@ -1,7 +1,7 @@
 import base64
 import io
 import json
 import logging
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast

@@ -12,6 +12,7 @@ import google.generativeai.client as client
 import requests
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part
+from PIL import Image

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (

@@ -371,7 +372,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 # fetch image data from url
                 try:
                     image_content = requests.get(message_content.data).content
-                    mime_type, _ = mimetypes.guess_type(message_content.data)
+                    with Image.open(io.BytesIO(image_content)) as img:
+                        mime_type = f"image/{img.format.lower()}"
                     base64_data = base64.b64encode(image_content).decode('utf-8')
                 except Exception as ex:
                     raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

@@ -1,4 +1,5 @@
 import base64
+import io
 import json
 import logging
 from collections.abc import Generator

@@ -18,6 +19,7 @@ from anthropic.types import (
 )
 from google.cloud import aiplatform
 from google.oauth2 import service_account
+from PIL import Image
 from vertexai.generative_models import HarmBlockThreshold, HarmCategory

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage

@@ -332,7 +334,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                 # fetch image data from url
                 try:
                     image_content = requests.get(message_content.data).content
-                    mime_type, _ = mimetypes.guess_type(message_content.data)
+                    with Image.open(io.BytesIO(image_content)) as img:
+                        mime_type = f"image/{img.format.lower()}"
                     base64_data = base64.b64encode(image_content).decode('utf-8')
                 except Exception as ex:
                     raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

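Note (illustration, not part of the diff): the four provider hunks above all make the same change. Instead of inferring the image MIME type from the URL with mimetypes.guess_type, which fails for signed or extension-less URLs, the downloaded bytes are opened with Pillow and the detected format is used. A minimal stand-alone sketch of that shared pattern, with a hypothetical helper name:

import base64
import io

import requests
from PIL import Image


def fetch_image_as_base64(url: str) -> tuple[str, str]:
    """Hypothetical helper: return (mime_type, base64_data) for the image at url."""
    try:
        image_content = requests.get(url).content
        # Detect the format from the decoded bytes rather than from the URL suffix.
        with Image.open(io.BytesIO(image_content)) as img:
            mime_type = f"image/{img.format.lower()}"
        base64_data = base64.b64encode(image_content).decode('utf-8')
    except Exception as ex:
        raise ValueError(f"Failed to fetch image data from url {url}, {ex}")
    return mime_type, base64_data
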
@ -1,6 +1,25 @@
|
|||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from volcenginesdkarkruntime.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
|
||||
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
|
||||
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
|
|||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
|
||||
|
||||
|
||||
class MaaSClient(MaasService):
|
||||
def __init__(self, host: str, region: str):
|
||||
class ArkClientV3:
|
||||
endpoint_id: Optional[str] = None
|
||||
ark: Optional[Ark] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.ark = Ark(*args, **kwargs)
|
||||
self.endpoint_id = None
|
||||
super().__init__(host, region)
|
||||
|
||||
def set_endpoint_id(self, endpoint_id: str):
|
||||
self.endpoint_id = endpoint_id
|
||||
|
||||
@classmethod
|
||||
def from_credential(cls, credentials: dict) -> 'MaaSClient':
|
||||
host = credentials['api_endpoint_host']
|
||||
region = credentials['volc_region']
|
||||
ak = credentials['volc_access_key_id']
|
||||
sk = credentials['volc_secret_access_key']
|
||||
endpoint_id = credentials['endpoint_id']
|
||||
|
||||
client = cls(host, region)
|
||||
client.set_endpoint_id(endpoint_id)
|
||||
client.set_ak(ak)
|
||||
client.set_sk(sk)
|
||||
return client
|
||||
|
||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||
req = {
|
||||
'parameters': params,
|
||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||
**extra_model_kwargs,
|
||||
}
|
||||
if not stream:
|
||||
return super().chat(
|
||||
self.endpoint_id,
|
||||
req,
|
||||
)
|
||||
return super().stream_chat(
|
||||
self.endpoint_id,
|
||||
req,
|
||||
)
|
||||
|
||||
def embeddings(self, texts: list[str]) -> dict:
|
||||
req = {
|
||||
'input': texts
|
||||
}
|
||||
return super().embeddings(self.endpoint_id, req)
|
||||
|
||||
@staticmethod
|
||||
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
|
||||
def is_legacy(credentials: dict) -> bool:
|
||||
if ArkClientV3.is_compatible_with_legacy(credentials):
|
||||
return False
|
||||
sdk_version = credentials.get("sdk_version", "v2")
|
||||
return sdk_version != "v3"
|
||||
|
||||
@staticmethod
|
||||
def is_compatible_with_legacy(credentials: dict) -> bool:
|
||||
sdk_version = credentials.get("sdk_version")
|
||||
endpoint = credentials.get("api_endpoint_host")
|
||||
return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
|
||||
|
||||
@classmethod
|
||||
def from_credentials(cls, credentials):
|
||||
"""Initialize the client using the credentials provided."""
|
||||
args = {
|
||||
"base_url": credentials['api_endpoint_host'],
|
||||
"region": credentials['volc_region'],
|
||||
"ak": credentials['volc_access_key_id'],
|
||||
"sk": credentials['volc_secret_access_key'],
|
||||
}
|
||||
if cls.is_compatible_with_legacy(credentials):
|
||||
args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
|
||||
client = ArkClientV3(
|
||||
**args
|
||||
)
|
||||
client.endpoint_id = credentials['endpoint_id']
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
|
||||
"""Converts a PromptMessage to a ChatCompletionMessageParam"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": ChatRole.USER,
|
||||
"content": message.content}
|
||||
content = message.content
|
||||
else:
|
||||
content = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
raise ValueError(
|
||||
'Content object type only support image_url')
|
||||
content.append(ChatCompletionContentPartTextParam(
|
||||
text=message_content.text,
|
||||
type='text',
|
||||
))
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(
|
||||
ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(
|
||||
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||
content.append({
|
||||
'type': 'image_url',
|
||||
'image_url': {
|
||||
'url': '',
|
||||
'image_bytes': image_data,
|
||||
'detail': message_content.detail,
|
||||
}
|
||||
})
|
||||
|
||||
message_dict = {'role': ChatRole.USER, 'content': content}
|
||||
content.append(ChatCompletionContentPartImageParam(
|
||||
image_url=ImageURL(
|
||||
url=image_data,
|
||||
detail=message_content.detail.value,
|
||||
),
|
||||
type='image_url',
|
||||
))
|
||||
message_dict = ChatCompletionUserMessageParam(
|
||||
role='user',
|
||||
content=content
|
||||
)
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.ASSISTANT,
|
||||
'content': message.content}
|
||||
if message.tool_calls:
|
||||
message_dict['tool_calls'] = [
|
||||
{
|
||||
'name': call.function.name,
|
||||
'arguments': call.function.arguments
|
||||
} for call in message.tool_calls
|
||||
message_dict = ChatCompletionAssistantMessageParam(
|
||||
content=message.content,
|
||||
role='assistant',
|
||||
tool_calls=None if not message.tool_calls else [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=call.id,
|
||||
function=Function(
|
||||
name=call.function.name,
|
||||
arguments=call.function.arguments
|
||||
),
|
||||
type='function'
|
||||
) for call in message.tool_calls
|
||||
]
|
||||
)
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.SYSTEM,
|
||||
'content': message.content}
|
||||
message_dict = ChatCompletionSystemMessageParam(
|
||||
content=message.content,
|
||||
role='system'
|
||||
)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.FUNCTION,
|
||||
'content': message.content,
|
||||
'name': message.tool_call_id}
|
||||
message_dict = ChatCompletionToolMessageParam(
|
||||
content=message.content,
|
||||
role='tool',
|
||||
tool_call_id=message.tool_call_id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
@staticmethod
|
||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||
try:
|
||||
resp = fn()
|
||||
except MaasException as e:
|
||||
raise wrap_error(e)
|
||||
def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=FunctionDefinition(
|
||||
name=message.name,
|
||||
description=message.description,
|
||||
parameters=message.parameters,
|
||||
)
|
||||
)
|
||||
|
||||
return resp
|
||||
def chat(self, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
) -> ChatCompletion:
|
||||
"""Block chat"""
|
||||
return self.ark.chat.completions.create(
|
||||
model=self.endpoint_id,
|
||||
messages=[self.convert_prompt_message(message) for message in messages],
|
||||
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
}
|
||||
def stream_chat(self, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
) -> Generator[ChatCompletionChunk]:
|
||||
"""Stream chat"""
|
||||
chunks = self.ark.chat.completions.create(
|
||||
stream=True,
|
||||
model=self.endpoint_id,
|
||||
messages=[self.convert_prompt_message(message) for message in messages],
|
||||
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
)
|
||||
for chunk in chunks:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
yield chunk
|
||||
|
||||
def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
|
||||
return self.ark.embeddings.create(model=self.endpoint_id, input=texts)
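
Usage sketch (not part of the diff) for the new ArkClientV3 introduced above, mirroring how the credential-validation code elsewhere in this commit drives it; all credential values below are placeholders:

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3

credentials = {
    'api_endpoint_host': 'https://ark.cn-beijing.volces.com/api/v3',
    'volc_region': 'cn-beijing',                # placeholder
    'volc_access_key_id': '<ak>',               # placeholder
    'volc_secret_access_key': '<sk>',           # placeholder
    'endpoint_id': '<endpoint-id>',             # placeholder
}

client = ArkClientV3.from_credentials(credentials)

# Blocking call; client.stream_chat(...) yields ChatCompletionChunk objects instead.
resp = client.chat(
    messages=[UserPromptMessage(content='ping\nAnswer: ')],
    max_tokens=16,
    temperature=0.7,
    top_p=0.9,
)
print(resp.choices[0].message.content)
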
@ -0,0 +1,134 @@
|
|||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
||||
|
||||
|
||||
class MaaSClient(MaasService):
|
||||
def __init__(self, host: str, region: str):
|
||||
self.endpoint_id = None
|
||||
super().__init__(host, region)
|
||||
|
||||
def set_endpoint_id(self, endpoint_id: str):
|
||||
self.endpoint_id = endpoint_id
|
||||
|
||||
@classmethod
|
||||
def from_credential(cls, credentials: dict) -> 'MaaSClient':
|
||||
host = credentials['api_endpoint_host']
|
||||
region = credentials['volc_region']
|
||||
ak = credentials['volc_access_key_id']
|
||||
sk = credentials['volc_secret_access_key']
|
||||
endpoint_id = credentials['endpoint_id']
|
||||
|
||||
client = cls(host, region)
|
||||
client.set_endpoint_id(endpoint_id)
|
||||
client.set_ak(ak)
|
||||
client.set_sk(sk)
|
||||
return client
|
||||
|
||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||
req = {
|
||||
'parameters': params,
|
||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||
**extra_model_kwargs,
|
||||
}
|
||||
if not stream:
|
||||
return super().chat(
|
||||
self.endpoint_id,
|
||||
req,
|
||||
)
|
||||
return super().stream_chat(
|
||||
self.endpoint_id,
|
||||
req,
|
||||
)
|
||||
|
||||
def embeddings(self, texts: list[str]) -> dict:
|
||||
req = {
|
||||
'input': texts
|
||||
}
|
||||
return super().embeddings(self.endpoint_id, req)
|
||||
|
||||
@staticmethod
|
||||
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": ChatRole.USER,
|
||||
"content": message.content}
|
||||
else:
|
||||
content = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
raise ValueError(
|
||||
'Content object type only support image_url')
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(
|
||||
ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(
|
||||
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||
content.append({
|
||||
'type': 'image_url',
|
||||
'image_url': {
|
||||
'url': '',
|
||||
'image_bytes': image_data,
|
||||
'detail': message_content.detail,
|
||||
}
|
||||
})
|
||||
|
||||
message_dict = {'role': ChatRole.USER, 'content': content}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.ASSISTANT,
|
||||
'content': message.content}
|
||||
if message.tool_calls:
|
||||
message_dict['tool_calls'] = [
|
||||
{
|
||||
'name': call.function.name,
|
||||
'arguments': call.function.arguments
|
||||
} for call in message.tool_calls
|
||||
]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.SYSTEM,
|
||||
'content': message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.FUNCTION,
|
||||
'content': message.content,
|
||||
'name': message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
@staticmethod
|
||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||
try:
|
||||
resp = fn()
|
||||
except MaasException as e:
|
||||
raise wrap_error(e)
|
||||
|
||||
return resp
|
||||
|
||||
@staticmethod
|
||||
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
}
|
|
@@ -1,4 +1,4 @@
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException


 class ClientSDKRequestError(MaasException):

@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
|
@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
|
||||
from core.model_runtime.model_providers.volcengine_maas.errors import (
|
||||
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
|
||||
get_model_config,
|
||||
get_v2_req_params,
|
||||
get_v3_req_params,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
# ping
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._validate_credentials_v2(credentials)
|
||||
return self._validate_credentials_v3(credentials)
|
||||
|
||||
@staticmethod
|
||||
def _validate_credentials_v2(credentials: dict) -> None:
|
||||
client = MaaSClient.from_credential(credentials)
|
||||
try:
|
||||
client.chat(
|
||||
|
@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
except MaasException as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
@staticmethod
|
||||
def _validate_credentials_v3(credentials: dict) -> None:
|
||||
client = ArkClientV3.from_credentials(credentials)
|
||||
try:
|
||||
client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
|
||||
messages=[UserPromptMessage(content='ping\nAnswer: ')], )
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(e)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
if len(prompt_messages) == 0:
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._get_num_tokens_v2(prompt_messages)
|
||||
return self._get_num_tokens_v3(prompt_messages)
|
||||
|
||||
def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
|
||||
if len(messages) == 0:
|
||||
return 0
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
Calculate num tokens.
|
||||
|
||||
:param messages: messages
|
||||
"""
|
||||
num_tokens = 0
|
||||
messages_dict = [
|
||||
MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
|
||||
|
@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
|
||||
return num_tokens
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
|
||||
if len(messages) == 0:
|
||||
return 0
|
||||
num_tokens = 0
|
||||
messages_dict = [
|
||||
ArkClientV3.convert_prompt_message(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
for key, value in message.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(key))
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
|
||||
client = MaaSClient.from_credential(credentials)
|
||||
|
@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
]
|
||||
resp = MaaSClient.wrap_exception(
|
||||
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
||||
if not stream:
|
||||
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
||||
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
||||
|
||||
def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
|
||||
for index, r in enumerate(resp):
|
||||
choices = r['choices']
|
||||
def _handle_stream_chat_response() -> Generator:
|
||||
for index, r in enumerate(resp):
|
||||
choices = r['choices']
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
message = choice['message']
|
||||
usage = None
|
||||
if r.get('usage'):
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=r['usage']['prompt_tokens'],
|
||||
completion_tokens=r['usage']['completion_tokens']
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(
|
||||
content=message['content'] if message['content'] else '',
|
||||
tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
finish_reason=choice.get('finish_reason'),
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_chat_response() -> LLMResult:
|
||||
choices = resp['choices']
|
||||
if not choices:
|
||||
continue
|
||||
raise ValueError("No choices found")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice['message']
|
||||
usage = None
|
||||
if r.get('usage'):
|
||||
usage = self._calc_usage(model, credentials, r['usage'])
|
||||
yield LLMResultChunk(
|
||||
|
||||
# parse tool calls
|
||||
tool_calls = []
|
||||
if message['tool_calls']:
|
||||
for call in message['tool_calls']:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=call['function']['name'],
|
||||
type=call['type'],
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=call['function']['name'],
|
||||
arguments=call['function']['arguments']
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
usage = resp['usage']
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(
|
||||
content=message['content'] if message['content'] else '',
|
||||
tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
finish_reason=choice.get('finish_reason'),
|
||||
message=AssistantPromptMessage(
|
||||
content=message['content'] if message['content'] else '',
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=usage['prompt_tokens'],
|
||||
completion_tokens=usage['completion_tokens']
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
|
||||
choices = resp['choices']
|
||||
if not choices:
|
||||
return
|
||||
choice = choices[0]
|
||||
message = choice['message']
|
||||
if not stream:
|
||||
return _handle_chat_response()
|
||||
return _handle_stream_chat_response()
|
||||
|
||||
# parse tool calls
|
||||
tool_calls = []
|
||||
if message['tool_calls']:
|
||||
for call in message['tool_calls']:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=call['function']['name'],
|
||||
type=call['type'],
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=call['function']['name'],
|
||||
arguments=call['function']['arguments']
|
||||
)
|
||||
def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
|
||||
client = ArkClientV3.from_credentials(credentials)
|
||||
req_params = get_v3_req_params(credentials, model_parameters, stop)
|
||||
if tools:
|
||||
req_params['tools'] = tools
|
||||
|
||||
def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
|
||||
for chunk in chunks:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=choice.index,
|
||||
message=AssistantPromptMessage(
|
||||
content=choice.delta.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens
|
||||
) if chunk.usage else None,
|
||||
finish_reason=choice.finish_reason,
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=message['content'] if message['content'] else '',
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
usage=self._calc_usage(model, credentials, resp['usage']),
|
||||
)
|
||||
def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
|
||||
choice = resp.choices[0]
|
||||
message = choice.message
|
||||
# parse tool calls
|
||||
tool_calls = []
|
||||
if message.tool_calls:
|
||||
for call in message.tool_calls:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=call.id,
|
||||
type=call.type,
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=call.function.name,
|
||||
arguments=call.function.arguments
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
|
||||
return self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=usage['prompt_tokens'],
|
||||
completion_tokens=usage['completion_tokens']
|
||||
)
|
||||
usage = resp.usage
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content if message.content else "",
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens
|
||||
),
|
||||
)
|
||||
|
||||
if not stream:
|
||||
resp = client.chat(prompt_messages, **req_params)
|
||||
return _handle_chat_response(resp)
|
||||
|
||||
chunks = client.stream_chat(prompt_messages, **req_params)
|
||||
return _handle_stream_chat_response(chunks)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
model_config = get_model_config(credentials)
|
||||
|
||||
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
|
@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
use_template='presence_penalty',
|
||||
label=I18nObject(
|
||||
en_US='Presence Penalty',
|
||||
zh_Hans= '存在惩罚',
|
||||
zh_Hans='存在惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
|
@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
type=ParameterType.FLOAT,
|
||||
use_template='frequency_penalty',
|
||||
label=I18nObject(
|
||||
en_US= 'Frequency Penalty',
|
||||
zh_Hans= '频率惩罚',
|
||||
en_US='Frequency Penalty',
|
||||
zh_Hans='频率惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
|
@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||
model_properties = {}
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
||||
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
|
||||
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
|
|
|
@@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature


 class ModelProperties(BaseModel):
     context_size: int
     max_tokens: int
+    mode: LLMMode


 class ModelConfig(BaseModel):
     properties: ModelProperties
     features: list[ModelFeature]

@@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Skylark2-pro-4k': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
         features=[]
     ),
     'Llama3-8B': ModelConfig(

@@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = {
     )
 }

-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = configs.get(base_model)
     if not model_configs:
         return ModelConfig(
             properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_tokens=int(credentials.get('max_tokens', 0)),
-                mode= LLMMode.value_of(credentials.get('mode', 'chat')),
+                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
             ),
             features=[]
         )
     return model_configs


 def get_v2_req_params(credentials: dict, model_parameters: dict,
-                      stop: list[str] | None=None):
+                      stop: list[str] | None = None):
     req_params = {}
     # predefined properties
     model_configs = get_model_config(credentials)

@@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
     if model_parameters.get('frequency_penalty'):
         req_params['frequency_penalty'] = model_parameters.get(
             'frequency_penalty')

     if stop:
         req_params['stop'] = stop

     return req_params
+
+
+def get_v3_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params

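Quick illustration (not part of the diff) of the fallback branch in get_model_config: when the base model is not in the predefined table, a ModelConfig is built from the credentials themselves. The values below are hypothetical:

credentials = {
    'base_model_name': 'My-Custom-Endpoint',   # not listed in configs
    'context_size': '8192',
    'max_tokens': '4096',
    'mode': 'chat',                            # resolved via LLMMode.value_of('chat')
}

config = get_model_config(credentials)
assert config.properties.context_size == 8192
assert config.properties.max_tokens == 4096
assert config.features == []
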
@@ -2,26 +2,29 @@ from pydantic import BaseModel


 class ModelProperties(BaseModel):
     context_size: int
     max_chunks: int


 class ModelConfig(BaseModel):
     properties: ModelProperties


 ModelConfigs = {
     'Doubao-embedding': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_chunks=1)
+        properties=ModelProperties(context_size=4096, max_chunks=32)
     ),
 }

-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = ModelConfigs.get(base_model)
     if not model_configs:
         return ModelConfig(
             properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_chunks=int(credentials.get('max_chunks', 0)),
             )
         )
     return model_configs

@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
|
||||
from core.model_runtime.model_providers.volcengine_maas.errors import (
|
||||
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||
|
||||
|
||||
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._generate_v2(model, credentials, texts, user)
|
||||
|
||||
return self._generate_v3(model, credentials, texts, user)
|
||||
|
||||
def _generate_v2(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
client = MaaSClient.from_credential(credentials)
|
||||
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
||||
|
||||
|
@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||
|
||||
return result
|
||||
|
||||
def _generate_v3(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
client = ArkClientV3.from_credentials(credentials)
|
||||
resp = client.embeddings(texts)
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, tokens=resp.usage.total_tokens)
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=[v.embedding for v in resp.data],
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
if ArkClientV3.is_legacy(credentials):
|
||||
return self._validate_credentials_v2(model, credentials)
|
||||
return self._validate_credentials_v3(model, credentials)
|
||||
|
||||
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except MaasException as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(e)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
|
@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||
generate custom model entities from credentials
|
||||
"""
|
||||
model_config = get_model_config(credentials)
|
||||
model_properties = {}
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
||||
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
|
||||
model_properties = {
|
||||
ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
|
||||
ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
|
||||
}
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
|
|
|
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" fill="none" version="1.1" width="128" height="128" viewBox="0 0 128 128"><g><g style="opacity:0;"><rect x="0" y="0" width="128" height="128" rx="0" fill="#FFFFFF" fill-opacity="1"/></g><g><path d="M100.74,12L93.2335,12C69.21260000000001,12,55.3672,27.3468,55.3672,50.8672L55.3672,54.8988C52.6011,54.1056,49.7377,53.7031,46.8601,53.7031C29.816499999999998,53.7031,16,67.5196,16,84.5632C16,101.6069,29.816499999999998,115.423,46.8601,115.423C63.9037,115.423,77.72030000000001,101.6069,77.72030000000001,84.5632C77.72030000000001,82.4902,77.51140000000001,80.4223,77.0967,78.3911L77.2197,78.3911L100.74,78.3911C106.9654,78.3681,112,73.3151,112,67.08959999999999C112,60.8642,106.9654,55.8111,100.74,55.7882L100.7362,55.7882L100.6985,55.7879L100.6606,55.7882L77.2197,55.7882L77.2195,49.8663C77.2195,40.8584,83.7252,34.352900000000005,93.2335,34.352900000000005L100.5653,34.352900000000005L100.5733,34.352900000000005L100.5812,34.352900000000005L100.74,34.352900000000005L100.74,34.352900000000005C106.8469,34.2605,111.7497,29.284,111.7497,23.1764C111.7497,17.06889,106.8469,12.0923454,100.74,12L100.74,12ZM56.0347,84.5632C56.0347,79.4962,51.9271,75.3885,46.8601,75.3885C41.793099999999995,75.3885,37.6854,79.4962,37.6854,84.5632C37.6854,89.6303,41.793099999999995,93.7378,46.8601,93.7378C51.9271,93.7378,56.0347,89.6303,56.0347,84.5632Z" fill-rule="evenodd" fill="#8358F6" fill-opacity="1"/></g></g></svg>
(New image file, 1.4 KiB.)

api/core/tools/provider/builtin/siliconflow/siliconflow.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class SiliconflowProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        url = "https://api.siliconflow.cn/v1/models"
        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {credentials.get('siliconFlow_api_key')}",
        }

        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise ToolProviderCredentialValidationError(
                "SiliconFlow API key is invalid"
            )

api/core/tools/provider/builtin/siliconflow/siliconflow.yaml (new file, 21 lines)
@@ -0,0 +1,21 @@
identity:
  author: hjlarry
  name: siliconflow
  label:
    en_US: SiliconFlow
    zh_CN: 硅基流动
  description:
    en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models.
    zh_CN: 硅基流动提供的图片生成 API,包含 Flux 和 Stable Diffusion 模型。
  icon: icon.svg
  tags:
    - image
credentials_for_provider:
  siliconFlow_api_key:
    type: secret-input
    required: true
    label:
      en_US: SiliconFlow API Key
    placeholder:
      en_US: Please input your SiliconFlow API key
    url: https://cloud.siliconflow.cn/account/ak

api/core/tools/provider/builtin/siliconflow/tools/flux.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

FLUX_URL = (
    "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image"
)


class FluxTool(BuiltinTool):

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}",
        }

        payload = {
            "prompt": tool_parameters.get("prompt"),
            "image_size": tool_parameters.get("image_size", "1024x1024"),
            "seed": tool_parameters.get("seed"),
            "num_inference_steps": tool_parameters.get("num_inference_steps", 20),
        }

        response = requests.post(FLUX_URL, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        res = response.json()
        result = [self.create_json_message(res)]
        for image in res.get("images", []):
            result.append(
                self.create_image_message(
                    image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value
                )
            )
        return result

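For reference, a stand-alone sketch (not part of the diff) of the HTTP call the FluxTool above wraps; the API key is a placeholder and the response shape is inferred from how the tool reads it:

import requests

FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image"

headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "authorization": "Bearer <siliconFlow_api_key>",  # placeholder
}
payload = {
    "prompt": "a watercolor fox in a snowy forest",
    "image_size": "1024x1024",
    "num_inference_steps": 20,
}

response = requests.post(FLUX_URL, json=payload, headers=headers)
response.raise_for_status()
# Each entry of the "images" list carries a "url" to a generated image.
for image in response.json().get("images", []):
    print(image.get("url"))
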
api/core/tools/provider/builtin/siliconflow/tools/flux.yaml (new file, 73 lines)
|
@ -0,0 +1,73 @@
|
|||
identity:
|
||||
name: flux
|
||||
author: hjlarry
|
||||
label:
|
||||
en_US: Flux
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Generate image via SiliconFlow's flux schnell.
|
||||
llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: The text prompt used to generate the image.
|
||||
zh_Hans: 用于生成图片的文字提示词
|
||||
llm_description: this prompt text will be used to generate image.
|
||||
form: llm
|
||||
- name: image_size
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: 1024x1024
|
||||
label:
|
||||
en_US: 1024x1024
|
||||
- value: 768x1024
|
||||
label:
|
||||
en_US: 768x1024
|
||||
- value: 576x1024
|
||||
label:
|
||||
en_US: 576x1024
|
||||
- value: 512x1024
|
||||
label:
|
||||
en_US: 512x1024
|
||||
- value: 1024x576
|
||||
label:
|
||||
en_US: 1024x576
|
||||
- value: 768x512
|
||||
label:
|
||||
en_US: 768x512
|
||||
default: 1024x1024
|
||||
label:
|
||||
en_US: Choose Image Size
|
||||
zh_Hans: 选择生成的图片大小
|
||||
form: form
|
||||
- name: num_inference_steps
|
||||
type: number
|
||||
required: true
|
||||
default: 20
|
||||
min: 1
|
||||
max: 100
|
||||
label:
|
||||
en_US: Num Inference Steps
|
||||
zh_Hans: 生成图片的步数
|
||||
form: form
|
||||
human_description:
|
||||
en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
|
||||
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
|
||||
- name: seed
|
||||
type: number
|
||||
min: 0
|
||||
max: 9999999999
|
||||
label:
|
||||
en_US: Seed
|
||||
zh_Hans: 种子
|
||||
human_description:
|
||||
en_US: The same seed and prompt can produce similar images.
|
||||
zh_Hans: 相同的种子和提示可以产生相似的图像。
|
||||
form: form
|
|
@@ -0,0 +1,51 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

SDURL = {
    "sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
    "sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
}


class StableDiffusionTool(BuiltinTool):

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}",
        }

        model = tool_parameters.get("model", "sd_3")
        url = SDURL.get(model)

        payload = {
            "prompt": tool_parameters.get("prompt"),
            "negative_prompt": tool_parameters.get("negative_prompt", ""),
            "image_size": tool_parameters.get("image_size", "1024x1024"),
            "batch_size": tool_parameters.get("batch_size", 1),
            "seed": tool_parameters.get("seed"),
            "guidance_scale": tool_parameters.get("guidance_scale", 7.5),
            "num_inference_steps": tool_parameters.get("num_inference_steps", 20),
        }

        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        res = response.json()
        result = [self.create_json_message(res)]
        for image in res.get("images", []):
            result.append(
                self.create_image_message(
                    image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value
                )
            )
        return result

|
@ -0,0 +1,121 @@
|
|||
identity:
  name: stable_diffusion
  author: hjlarry
  label:
    en_US: Stable Diffusion
  icon: icon.svg
description:
  human:
    en_US: Generate image via SiliconFlow's stable diffusion model.
  llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: prompt
      zh_Hans: 提示词
    human_description:
      en_US: The text prompt used to generate the image.
      zh_Hans: 用于生成图片的文字提示词
    llm_description: this prompt text will be used to generate image.
    form: llm
  - name: negative_prompt
    type: string
    label:
      en_US: negative prompt
      zh_Hans: 负面提示词
    human_description:
      en_US: Describe what you don't want included in the image.
      zh_Hans: 描述您不希望包含在图片中的内容。
    llm_description: Describe what you don't want included in the image.
    form: llm
  - name: model
    type: select
    required: true
    options:
      - value: sd_3
        label:
          en_US: Stable Diffusion 3
      - value: sd_xl
        label:
          en_US: Stable Diffusion XL
    default: sd_3
    label:
      en_US: Choose Image Model
      zh_Hans: 选择生成图片的模型
    form: form
  - name: image_size
    type: select
    required: true
    options:
      - value: 1024x1024
        label:
          en_US: 1024x1024
      - value: 1024x2048
        label:
          en_US: 1024x2048
      - value: 1152x2048
        label:
          en_US: 1152x2048
      - value: 1536x1024
        label:
          en_US: 1536x1024
      - value: 1536x2048
        label:
          en_US: 1536x2048
      - value: 2048x1152
        label:
          en_US: 2048x1152
    default: 1024x1024
    label:
      en_US: Choose Image Size
      zh_Hans: 选择生成图片的大小
    form: form
  - name: batch_size
    type: number
    required: true
    default: 1
    min: 1
    max: 4
    label:
      en_US: Number Images
      zh_Hans: 生成图片的数量
    form: form
  - name: guidance_scale
    type: number
    required: true
    default: 7
    min: 0
    max: 100
    label:
      en_US: Guidance Scale
      zh_Hans: 与提示词紧密性
    human_description:
      en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you.
      zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。
    form: form
  - name: num_inference_steps
    type: number
    required: true
    default: 20
    min: 1
    max: 100
    label:
      en_US: Num Inference Steps
      zh_Hans: 生成图片的步数
    human_description:
      en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
      zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
    form: form
  - name: seed
    type: number
    min: 0
    max: 9999999999
    label:
      en_US: Seed
      zh_Hans: 种子
    human_description:
      en_US: The same seed and prompt can produce similar images.
      zh_Hans: 相同的种子和提示可以产生相似的图像。
    form: form
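For orientation, here is a hypothetical tool_parameters value that satisfies the schema above (every value is invented for illustration); the tool code earlier in this diff turns such a dict into the JSON payload posted to SiliconFlow.

    # Example input only; all values below are made up.
    example_parameters = {
        "prompt": "a watercolor lighthouse at dawn",
        "negative_prompt": "blurry, low quality",
        "model": "sd_xl",
        "image_size": "1024x1024",
        "batch_size": 1,
        "guidance_scale": 7,
        "num_inference_steps": 20,
        "seed": 42,
    }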
api/poetry.lock (generated)
@@ -6143,6 +6143,19 @@ files = [
    {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
    {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
    {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
    {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
    {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
    {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
    {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
    {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
]

[package.dependencies]
@@ -8854,6 +8867,28 @@ files = [
    {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
]

[[package]]
name = "volcengine-python-sdk"
version = "1.0.98"
description = "Volcengine SDK for Python"
optional = false
python-versions = "*"
files = [
    {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
]

[package.dependencies]
anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
certifi = ">=2017.4.17"
httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
python-dateutil = ">=2.1"
six = ">=1.10"
urllib3 = ">=1.23"

[package.extras]
ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]

[[package]]
name = "watchfiles"
version = "0.23.0"
@@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"
@@ -191,6 +191,7 @@ zhipuai = "1.0.7"
# Related transparent dependencies with pinned verion
# required by main implementations
############################################################
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
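The ark extra above pulls in anyio, httpx and pydantic for the Ark runtime client. A rough usage sketch follows; the import path and the OpenAI-style call are assumptions about the SDK's interface, not code from this commit, and the credential/endpoint values are placeholders.

    # Rough sketch (assumed API surface of volcengine-python-sdk[ark]).
    from volcenginesdkarkruntime import Ark

    client = Ark(api_key="YOUR_ARK_API_KEY")  # placeholder credential
    resp = client.chat.completions.create(
        model="YOUR_ENDPOINT_ID",  # placeholder Ark endpoint id
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)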
@@ -34,16 +34,17 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
   const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures)

   const updateAppDetail = async () => {
-    fetchAppDetail({ url: '/apps', id: appId }).then((res) => {
+    try {
+      const res = await fetchAppDetail({ url: '/apps', id: appId })
       if (systemFeatures.enable_web_sso_switch_component) {
-        fetchAppSSO({ appId }).then((ssoRes) => {
-          setAppDetail({ ...res, enable_sso: ssoRes.enabled })
-        })
+        const ssoRes = await fetchAppSSO({ appId })
+        setAppDetail({ ...res, enable_sso: ssoRes.enabled })
       }
       else {
         setAppDetail({ ...res })
       }
-    })
+    }
+    catch (error) { console.error(error) }
   }

   const handleCallbackResult = (err: Error | null, message?: string) => {
@@ -43,7 +43,7 @@ export type ConfigParams = {
   icon: string
   icon_background?: string
   show_workflow_steps: boolean
-  enable_sso: boolean
+  enable_sso?: boolean
 }

 const prefixSettings = 'appOverview.overview.appInfo.settings'
@@ -157,7 +157,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
       icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId,
       icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
       show_workflow_steps: inputInfo.show_workflow_steps,
-      enable_sso: inputInfo.enable_sso!,
+      enable_sso: inputInfo.enable_sso,
     }
     await onSave?.(params)
     setSaveLoading(false)
@@ -235,7 +235,11 @@ const SettingsModal: FC<ISettingsModalProps> = ({
             <p className='system-xs-medium text-gray-500'>{t(`${prefixSettings}.sso.label`)}</p>
             <div className='flex justify-between items-center'>
               <div className='font-medium system-sm-semibold flex-grow text-gray-900'>{t(`${prefixSettings}.sso.title`)}</div>
-              <Tooltip asChild={false} disabled={systemFeatures.sso_enforced_for_web} popupContent={<div className='w-[180px]'>{t(`${prefixSettings}.sso.tooltip`)}</div>}>
+              <Tooltip
+                disabled={systemFeatures.sso_enforced_for_web}
+                popupContent={<div className='w-[180px]'>{t(`${prefixSettings}.sso.tooltip`)}</div>}
+                asChild={false}
+              >
                 <Switch disabled={!systemFeatures.sso_enforced_for_web} defaultValue={systemFeatures.sso_enforced_for_web && inputInfo.enable_sso} onChange={v => setInputInfo({ ...inputInfo, enable_sso: v })}></Switch>
               </Tooltip>
             </div>
@@ -89,7 +89,7 @@ const Tooltip: FC<TooltipProps> = ({
         onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)}
         asChild={asChild}
       >
-        {children || <div className={triggerClassName || 'p-[1px] w-3.5 h-3.5'}><RiQuestionLine className='text-text-quaternary hover:text-text-tertiary w-full h-full' /></div>}
+        {children || <div className={triggerClassName || 'p-[1px] w-3.5 h-3.5 shrink-0'}><RiQuestionLine className='text-text-quaternary hover:text-text-tertiary w-full h-full' /></div>}
       </PortalToFollowElemTrigger>
       <PortalToFollowElemContent
         className="z-[9999]"
@@ -1,6 +1,5 @@
import { Fragment, useState } from 'react'
import type { FC } from 'react'

import { ValidatingTip } from '../../key-validator/ValidateStatus'
import type {
  CredentialFormSchema,
@@ -242,13 +242,14 @@ const ParameterItem: FC<ParameterItemProps> = ({
            <div className='w-[200px] whitespace-pre-wrap'>{parameterRule.help[language] || parameterRule.help.en_US}</div>
          )}
          popupClassName='mr-1'
          triggerClassName='mr-1 w-4 h-4'
          triggerClassName='mr-1 w-4 h-4 shrink-0'
        />
      )
    }
    {
      !parameterRule.required && parameterRule.name !== 'stop' && (
        <Switch
          className='mr-1'
          defaultValue={!isNullOrUndefined(value)}
          onChange={handleSwitch}
          size='md'
@@ -3,9 +3,8 @@ import type { FC } from 'react'
 import React, { useCallback } from 'react'
 import type { VariantProps } from 'class-variance-authority'
 import { cva } from 'class-variance-authority'
-import { RiQuestionLine } from '@remixicon/react'
 import cn from '@/utils/classnames'
-import TooltipPlus from '@/app/components/base/tooltip'
+import Tooltip from '@/app/components/base/tooltip'

 const variants = cva([], {
   variants: {
@@ -59,15 +58,16 @@ const OptionCard: FC<Props> = ({
       onClick={handleSelect}
     >
       <span>{title}</span>
-      {tooltip && <TooltipPlus
-        popupContent={<div className='w-[240px]'
-        >
-          {tooltip}
-        </div>}
-        asChild={false}
-      >
-        <RiQuestionLine className='ml-0.5 w-[14px] h-[14px] text-text-quaternary' />
-      </TooltipPlus>}
+      {tooltip
+        && <Tooltip
+          popupContent={
+            <div className='w-[240px]'>
+              {tooltip}
+            </div>
+          }
+          asChild={false}
+        />
+      }
     </div>
   )
 }
@@ -54,7 +54,7 @@ const translation = {
     chatColorThemeInverted: 'Inverted',
     invalidHexMessage: 'Invalid hex value',
     sso: {
-      label: 'SSO ENFORCEMENT',
+      label: 'SSO Authentication',
       title: 'WebApp SSO',
       description: 'All users are required to login with SSO before using WebApp',
       tooltip: 'Contact the administrator to enable WebApp SSO',
@@ -16,7 +16,7 @@ export const fetchAppDetail = ({ url, id }: { url: string; id: string }) => {
 export const fetchAppSSO = async ({ appId }: { appId: string }) => {
   return get<AppSSOResponse>(`/enterprise/app-setting/sso?appID=${appId}`)
 }
-export const updateAppSSO = async ({ id, enabled }: { id: string;enabled: boolean }) => {
+export const updateAppSSO = async ({ id, enabled }: { id: string; enabled: boolean }) => {
   return post('/enterprise/app-setting/sso', { body: { app_id: id, enabled } })
 }