Merge branch 'feat/tooltip' into deploy/dev
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions

This commit is contained in:
Yi 2024-08-26 10:47:46 +08:00
commit 2ba2594718
37 changed files with 980 additions and 226 deletions

View File

@ -92,7 +92,7 @@ class AppMetaApi(Resource):
class AppInfoApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app infomation"""
"""Get app information"""
return {
'name':app_model.name,
'description':app_model.description

View File

@ -1,6 +1,6 @@
import base64
import io
import json
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -18,6 +18,7 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
@ -462,7 +463,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

View File

@ -1,8 +1,8 @@
# standard import
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -17,6 +17,7 @@ from botocore.exceptions import (
ServiceNotInRegionError,
UnknownServiceError,
)
from PIL.Image import Image
# local import
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
@ -381,9 +382,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
try:
url = message_content.data
image_content = requests.get(url).content
if '?' in url:
url = url.split('?')[0]
mime_type, _ = mimetypes.guess_type(url)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

View File

@ -1,7 +1,7 @@
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
@ -12,6 +12,7 @@ import google.generativeai.client as client
import requests
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@ -371,7 +372,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

View File

@ -1,4 +1,5 @@
import base64
import io
import json
import logging
from collections.abc import Generator
@ -18,6 +19,7 @@ from anthropic.types import (
)
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -332,7 +334,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")

View File

@ -1,6 +1,25 @@
import re
from collections.abc import Callable, Generator
from typing import cast
from collections.abc import Generator
from typing import Optional, cast
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
class MaaSClient(MaasService):
def __init__(self, host: str, region: str):
class ArkClientV3:
endpoint_id: Optional[str] = None
ark: Optional[Ark] = None
def __init__(self, *args, **kwargs):
self.ark = Ark(*args, **kwargs)
self.endpoint_id = None
super().__init__(host, region)
def set_endpoint_id(self, endpoint_id: str):
self.endpoint_id = endpoint_id
@classmethod
def from_credential(cls, credentials: dict) -> 'MaaSClient':
host = credentials['api_endpoint_host']
region = credentials['volc_region']
ak = credentials['volc_access_key_id']
sk = credentials['volc_secret_access_key']
endpoint_id = credentials['endpoint_id']
client = cls(host, region)
client.set_endpoint_id(endpoint_id)
client.set_ak(ak)
client.set_sk(sk)
return client
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
req = {
'parameters': params,
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
**extra_model_kwargs,
}
if not stream:
return super().chat(
self.endpoint_id,
req,
)
return super().stream_chat(
self.endpoint_id,
req,
)
def embeddings(self, texts: list[str]) -> dict:
req = {
'input': texts
}
return super().embeddings(self.endpoint_id, req)
@staticmethod
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
def is_legacy(credentials: dict) -> bool:
if ArkClientV3.is_compatible_with_legacy(credentials):
return False
sdk_version = credentials.get("sdk_version", "v2")
return sdk_version != "v3"
@staticmethod
def is_compatible_with_legacy(credentials: dict) -> bool:
sdk_version = credentials.get("sdk_version")
endpoint = credentials.get("api_endpoint_host")
return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
@classmethod
def from_credentials(cls, credentials):
"""Initialize the client using the credentials provided."""
args = {
"base_url": credentials['api_endpoint_host'],
"region": credentials['volc_region'],
"ak": credentials['volc_access_key_id'],
"sk": credentials['volc_secret_access_key'],
}
if cls.is_compatible_with_legacy(credentials):
args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
client = ArkClientV3(
**args
)
client.endpoint_id = credentials['endpoint_id']
return client
@staticmethod
def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
"""Converts a PromptMessage to a ChatCompletionMessageParam"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": ChatRole.USER,
"content": message.content}
content = message.content
else:
content = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
raise ValueError(
'Content object type only support image_url')
content.append(ChatCompletionContentPartTextParam(
text=message_content.text,
type='text',
))
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content)
image_data = re.sub(
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
content.append({
'type': 'image_url',
'image_url': {
'url': '',
'image_bytes': image_data,
'detail': message_content.detail,
}
})
message_dict = {'role': ChatRole.USER, 'content': content}
content.append(ChatCompletionContentPartImageParam(
image_url=ImageURL(
url=image_data,
detail=message_content.detail.value,
),
type='image_url',
))
message_dict = ChatCompletionUserMessageParam(
role='user',
content=content
)
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {'role': ChatRole.ASSISTANT,
'content': message.content}
if message.tool_calls:
message_dict['tool_calls'] = [
{
'name': call.function.name,
'arguments': call.function.arguments
} for call in message.tool_calls
message_dict = ChatCompletionAssistantMessageParam(
content=message.content,
role='assistant',
tool_calls=None if not message.tool_calls else [
ChatCompletionMessageToolCallParam(
id=call.id,
function=Function(
name=call.function.name,
arguments=call.function.arguments
),
type='function'
) for call in message.tool_calls
]
)
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {'role': ChatRole.SYSTEM,
'content': message.content}
message_dict = ChatCompletionSystemMessageParam(
content=message.content,
role='system'
)
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {'role': ChatRole.FUNCTION,
'content': message.content,
'name': message.tool_call_id}
message_dict = ChatCompletionToolMessageParam(
content=message.content,
role='tool',
tool_call_id=message.tool_call_id
)
else:
raise ValueError(f"Got unknown PromptMessage type {message}")
return message_dict
@staticmethod
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
raise wrap_error(e)
def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
return ChatCompletionToolParam(
type='function',
function=FunctionDefinition(
name=message.name,
description=message.description,
parameters=message.parameters,
)
)
return resp
def chat(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
) -> ChatCompletion:
"""Block chat"""
return self.ark.chat.completions.create(
model=self.endpoint_id,
messages=[self.convert_prompt_message(message) for message in messages],
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
stop=stop,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
top_p=top_p,
temperature=temperature,
)
@staticmethod
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
}
def stream_chat(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
) -> Generator[ChatCompletionChunk]:
"""Stream chat"""
chunks = self.ark.chat.completions.create(
stream=True,
model=self.endpoint_id,
messages=[self.convert_prompt_message(message) for message in messages],
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
stop=stop,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
top_p=top_p,
temperature=temperature,
)
for chunk in chunks:
if not chunk.choices:
continue
yield chunk
def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
return self.ark.embeddings.create(model=self.endpoint_id, input=texts)

View File

@ -0,0 +1,134 @@
import re
from collections.abc import Callable, Generator
from typing import cast
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
class MaaSClient(MaasService):
def __init__(self, host: str, region: str):
self.endpoint_id = None
super().__init__(host, region)
def set_endpoint_id(self, endpoint_id: str):
self.endpoint_id = endpoint_id
@classmethod
def from_credential(cls, credentials: dict) -> 'MaaSClient':
host = credentials['api_endpoint_host']
region = credentials['volc_region']
ak = credentials['volc_access_key_id']
sk = credentials['volc_secret_access_key']
endpoint_id = credentials['endpoint_id']
client = cls(host, region)
client.set_endpoint_id(endpoint_id)
client.set_ak(ak)
client.set_sk(sk)
return client
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
req = {
'parameters': params,
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
**extra_model_kwargs,
}
if not stream:
return super().chat(
self.endpoint_id,
req,
)
return super().stream_chat(
self.endpoint_id,
req,
)
def embeddings(self, texts: list[str]) -> dict:
req = {
'input': texts
}
return super().embeddings(self.endpoint_id, req)
@staticmethod
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": ChatRole.USER,
"content": message.content}
else:
content = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
raise ValueError(
'Content object type only support image_url')
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content)
image_data = re.sub(
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
content.append({
'type': 'image_url',
'image_url': {
'url': '',
'image_bytes': image_data,
'detail': message_content.detail,
}
})
message_dict = {'role': ChatRole.USER, 'content': content}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {'role': ChatRole.ASSISTANT,
'content': message.content}
if message.tool_calls:
message_dict['tool_calls'] = [
{
'name': call.function.name,
'arguments': call.function.arguments
} for call in message.tool_calls
]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {'role': ChatRole.SYSTEM,
'content': message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {'role': ChatRole.FUNCTION,
'content': message.content,
'name': message.tool_call_id}
else:
raise ValueError(f"Got unknown PromptMessage type {message}")
return message_dict
@staticmethod
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
raise wrap_error(e)
return resp
@staticmethod
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
}

View File

@ -1,4 +1,4 @@
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
class ClientSDKRequestError(MaasException):

View File

@ -1,8 +1,10 @@
import logging
from collections.abc import Generator
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.errors import (
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
get_v3_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__)
@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
if ArkClientV3.is_legacy(credentials):
return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate credentials
"""
# ping
if ArkClientV3.is_legacy(credentials):
return self._validate_credentials_v2(credentials)
return self._validate_credentials_v3(credentials)
@staticmethod
def _validate_credentials_v2(credentials: dict) -> None:
client = MaaSClient.from_credential(credentials)
try:
client.chat(
@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
except MaasException as e:
raise CredentialsValidateFailedError(e.message)
@staticmethod
def _validate_credentials_v3(credentials: dict) -> None:
client = ArkClientV3.from_credentials(credentials)
try:
client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
messages=[UserPromptMessage(content='ping\nAnswer: ')], )
except Exception as e:
raise CredentialsValidateFailedError(e)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
if len(prompt_messages) == 0:
if ArkClientV3.is_legacy(credentials):
return self._get_num_tokens_v2(prompt_messages)
return self._get_num_tokens_v3(prompt_messages)
def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
if len(messages) == 0:
return 0
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
"""
Calculate num tokens.
:param messages: messages
"""
num_tokens = 0
messages_dict = [
MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
if len(messages) == 0:
return 0
num_tokens = 0
messages_dict = [
ArkClientV3.convert_prompt_message(m) for m in messages]
for message in messages_dict:
for key, value in message.items():
num_tokens += self._get_num_tokens_by_gpt2(str(key))
num_tokens += self._get_num_tokens_by_gpt2(str(value))
return num_tokens
def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
client = MaaSClient.from_credential(credentials)
@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
return self._handle_chat_response(model, credentials, prompt_messages, resp)
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
for index, r in enumerate(resp):
choices = r['choices']
def _handle_stream_chat_response() -> Generator:
for index, r in enumerate(resp):
choices = r['choices']
if not choices:
continue
choice = choices[0]
message = choice['message']
usage = None
if r.get('usage'):
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=r['usage']['prompt_tokens'],
completion_tokens=r['usage']['completion_tokens']
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=[]
),
usage=usage,
finish_reason=choice.get('finish_reason'),
),
)
def _handle_chat_response() -> LLMResult:
choices = resp['choices']
if not choices:
continue
raise ValueError("No choices found")
choice = choices[0]
message = choice['message']
usage = None
if r.get('usage'):
usage = self._calc_usage(model, credentials, r['usage'])
yield LLMResultChunk(
# parse tool calls
tool_calls = []
if message['tool_calls']:
for call in message['tool_calls']:
tool_call = AssistantPromptMessage.ToolCall(
id=call['function']['name'],
type=call['type'],
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call['function']['name'],
arguments=call['function']['arguments']
)
)
tool_calls.append(tool_call)
usage = resp['usage']
return LLMResult(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=[]
),
usage=usage,
finish_reason=choice.get('finish_reason'),
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=tool_calls,
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage['prompt_tokens'],
completion_tokens=usage['completion_tokens']
),
)
def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
choices = resp['choices']
if not choices:
return
choice = choices[0]
message = choice['message']
if not stream:
return _handle_chat_response()
return _handle_stream_chat_response()
# parse tool calls
tool_calls = []
if message['tool_calls']:
for call in message['tool_calls']:
tool_call = AssistantPromptMessage.ToolCall(
id=call['function']['name'],
type=call['type'],
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call['function']['name'],
arguments=call['function']['arguments']
)
def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
client = ArkClientV3.from_credentials(credentials)
req_params = get_v3_req_params(credentials, model_parameters, stop)
if tools:
req_params['tools'] = tools
def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
for chunk in chunks:
if not chunk.choices:
continue
choice = chunk.choices[0]
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=choice.index,
message=AssistantPromptMessage(
content=choice.delta.content,
tool_calls=[]
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens
) if chunk.usage else None,
finish_reason=choice.finish_reason,
),
)
tool_calls.append(tool_call)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=tool_calls,
),
usage=self._calc_usage(model, credentials, resp['usage']),
)
def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
choice = resp.choices[0]
message = choice.message
# parse tool calls
tool_calls = []
if message.tool_calls:
for call in message.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=call.id,
type=call.type,
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call.function.name,
arguments=call.function.arguments
)
)
tool_calls.append(tool_call)
def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
return self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage['prompt_tokens'],
completion_tokens=usage['completion_tokens']
)
usage = resp.usage
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=message.content if message.content else "",
tool_calls=tool_calls,
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens
),
)
if not stream:
resp = client.chat(prompt_messages, **req_params)
return _handle_chat_response(resp)
chunks = client.stream_chat(prompt_messages, **req_params)
return _handle_stream_chat_response(chunks)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
model_config = get_model_config(credentials)
rules = [
ParameterRule(
name='temperature',
@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
use_template='presence_penalty',
label=I18nObject(
en_US='Presence Penalty',
zh_Hans= '存在惩罚',
zh_Hans='存在惩罚',
),
min=-2.0,
max=2.0,
@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.FLOAT,
use_template='frequency_penalty',
label=I18nObject(
en_US= 'Frequency Penalty',
zh_Hans= '频率惩罚',
en_US='Frequency Penalty',
zh_Hans='频率惩罚',
),
min=-2.0,
max=2.0,
@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
entity = AIModelEntity(
model=model,
label=I18nObject(

View File

@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature
class ModelProperties(BaseModel):
context_size: int
max_tokens: int
context_size: int
max_tokens: int
mode: LLMMode
class ModelConfig(BaseModel):
properties: ModelProperties
features: list[ModelFeature]
@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Skylark2-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Llama3-8B': ModelConfig(
@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = {
)
}
def get_model_config(credentials: dict)->ModelConfig:
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = configs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_tokens=int(credentials.get('max_tokens', 0)),
mode= LLMMode.value_of(credentials.get('mode', 'chat')),
mode=LLMMode.value_of(credentials.get('mode', 'chat')),
),
features=[]
)
return model_configs
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None=None):
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params
return req_params
def get_v3_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
req_params['max_tokens'] = model_configs.properties.max_tokens
# model parameters
if model_parameters.get('max_tokens'):
req_params['max_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params

View File

@ -2,26 +2,29 @@ from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = {
'Doubao-embedding': ModelConfig(
properties=ModelProperties(context_size=4096, max_chunks=1)
properties=ModelProperties(context_size=4096, max_chunks=32)
),
}
def get_model_config(credentials: dict)->ModelConfig:
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = ModelConfigs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_chunks=int(credentials.get('max_chunks', 0)),
)
)
return model_configs
return model_configs

View File

@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.errors import (
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
if ArkClientV3.is_legacy(credentials):
return self._generate_v2(model, credentials, texts, user)
return self._generate_v3(model, credentials, texts, user)
def _generate_v2(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
client = MaaSClient.from_credential(credentials)
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
return result
def _generate_v3(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
client = ArkClientV3.from_credentials(credentials)
resp = client.embeddings(texts)
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=resp.usage.total_tokens)
result = TextEmbeddingResult(
model=model,
embeddings=[v.embedding for v in resp.data],
usage=usage
)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
if ArkClientV3.is_legacy(credentials):
return self._validate_credentials_v2(model, credentials)
return self._validate_credentials_v3(model, credentials)
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
except MaasException as e:
raise CredentialsValidateFailedError(e.message)
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
except Exception as e:
raise CredentialsValidateFailedError(e)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
generate custom model entities from credentials
"""
model_config = get_model_config(credentials)
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
model_properties = {
ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
}
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" fill="none" version="1.1" width="128" height="128" viewBox="0 0 128 128"><g><g style="opacity:0;"><rect x="0" y="0" width="128" height="128" rx="0" fill="#FFFFFF" fill-opacity="1"/></g><g><path d="M100.74,12L93.2335,12C69.21260000000001,12,55.3672,27.3468,55.3672,50.8672L55.3672,54.8988C52.6011,54.1056,49.7377,53.7031,46.8601,53.7031C29.816499999999998,53.7031,16,67.5196,16,84.5632C16,101.6069,29.816499999999998,115.423,46.8601,115.423C63.9037,115.423,77.72030000000001,101.6069,77.72030000000001,84.5632C77.72030000000001,82.4902,77.51140000000001,80.4223,77.0967,78.3911L77.2197,78.3911L100.74,78.3911C106.9654,78.3681,112,73.3151,112,67.08959999999999C112,60.8642,106.9654,55.8111,100.74,55.7882L100.7362,55.7882L100.6985,55.7879L100.6606,55.7882L77.2197,55.7882L77.2195,49.8663C77.2195,40.8584,83.7252,34.352900000000005,93.2335,34.352900000000005L100.5653,34.352900000000005L100.5733,34.352900000000005L100.5812,34.352900000000005L100.74,34.352900000000005L100.74,34.352900000000005C106.8469,34.2605,111.7497,29.284,111.7497,23.1764C111.7497,17.06889,106.8469,12.0923454,100.74,12L100.74,12ZM56.0347,84.5632C56.0347,79.4962,51.9271,75.3885,46.8601,75.3885C41.793099999999995,75.3885,37.6854,79.4962,37.6854,84.5632C37.6854,89.6303,41.793099999999995,93.7378,46.8601,93.7378C51.9271,93.7378,56.0347,89.6303,56.0347,84.5632Z" fill-rule="evenodd" fill="#8358F6" fill-opacity="1"/></g></g></svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@ -0,0 +1,19 @@
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class SiliconflowProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
url = "https://api.siliconflow.cn/v1/models"
headers = {
"accept": "application/json",
"authorization": f"Bearer {credentials.get('siliconFlow_api_key')}",
}
response = requests.get(url, headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(
"SiliconFlow API key is invalid"
)

View File

@ -0,0 +1,21 @@
identity:
author: hjlarry
name: siliconflow
label:
en_US: SiliconFlow
zh_CN: 硅基流动
description:
en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models.
zh_CN: 硅基流动提供的图片生成 API包含 Flux 和 Stable Diffusion 模型。
icon: icon.svg
tags:
- image
credentials_for_provider:
siliconFlow_api_key:
type: secret-input
required: true
label:
en_US: SiliconFlow API Key
placeholder:
en_US: Please input your SiliconFlow API key
url: https://cloud.siliconflow.cn/account/ak

View File

@ -0,0 +1,44 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
FLUX_URL = (
"https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image"
)
class FluxTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}",
}
payload = {
"prompt": tool_parameters.get("prompt"),
"image_size": tool_parameters.get("image_size", "1024x1024"),
"seed": tool_parameters.get("seed"),
"num_inference_steps": tool_parameters.get("num_inference_steps", 20),
}
response = requests.post(FLUX_URL, json=payload, headers=headers)
if response.status_code != 200:
return self.create_text_message(f"Got Error Response:{response.text}")
res = response.json()
result = [self.create_json_message(res)]
for image in res.get("images", []):
result.append(
self.create_image_message(
image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value
)
)
return result

View File

@ -0,0 +1,73 @@
identity:
name: flux
author: hjlarry
label:
en_US: Flux
icon: icon.svg
description:
human:
en_US: Generate image via SiliconFlow's flux schnell.
llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to generate the image.
zh_Hans: 用于生成图片的文字提示词
llm_description: this prompt text will be used to generate image.
form: llm
- name: image_size
type: select
required: true
options:
- value: 1024x1024
label:
en_US: 1024x1024
- value: 768x1024
label:
en_US: 768x1024
- value: 576x1024
label:
en_US: 576x1024
- value: 512x1024
label:
en_US: 512x1024
- value: 1024x576
label:
en_US: 1024x576
- value: 768x512
label:
en_US: 768x512
default: 1024x1024
label:
en_US: Choose Image Size
zh_Hans: 选择生成的图片大小
form: form
- name: num_inference_steps
type: number
required: true
default: 20
min: 1
max: 100
label:
en_US: Num Inference Steps
zh_Hans: 生成图片的步数
form: form
human_description:
en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
- name: seed
type: number
min: 0
max: 9999999999
label:
en_US: Seed
zh_Hans: 种子
human_description:
en_US: The same seed and prompt can produce similar images.
zh_Hans: 相同的种子和提示可以产生相似的图像。
form: form

View File

@ -0,0 +1,51 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
SDURL = {
"sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
"sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
}
class StableDiffusionTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}",
}
model = tool_parameters.get("model", "sd_3")
url = SDURL.get(model)
payload = {
"prompt": tool_parameters.get("prompt"),
"negative_prompt": tool_parameters.get("negative_prompt", ""),
"image_size": tool_parameters.get("image_size", "1024x1024"),
"batch_size": tool_parameters.get("batch_size", 1),
"seed": tool_parameters.get("seed"),
"guidance_scale": tool_parameters.get("guidance_scale", 7.5),
"num_inference_steps": tool_parameters.get("num_inference_steps", 20),
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
return self.create_text_message(f"Got Error Response:{response.text}")
res = response.json()
result = [self.create_json_message(res)]
for image in res.get("images", []):
result.append(
self.create_image_message(
image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value
)
)
return result

View File

@ -0,0 +1,121 @@
identity:
name: stable_diffusion
author: hjlarry
label:
en_US: Stable Diffusion
icon: icon.svg
description:
human:
en_US: Generate image via SiliconFlow's stable diffusion model.
llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to generate the image.
zh_Hans: 用于生成图片的文字提示词
llm_description: this prompt text will be used to generate image.
form: llm
- name: negative_prompt
type: string
label:
en_US: negative prompt
zh_Hans: 负面提示词
human_description:
en_US: Describe what you don't want included in the image.
zh_Hans: 描述您不希望包含在图片中的内容。
llm_description: Describe what you don't want included in the image.
form: llm
- name: model
type: select
required: true
options:
- value: sd_3
label:
en_US: Stable Diffusion 3
- value: sd_xl
label:
en_US: Stable Diffusion XL
default: sd_3
label:
en_US: Choose Image Model
zh_Hans: 选择生成图片的模型
form: form
- name: image_size
type: select
required: true
options:
- value: 1024x1024
label:
en_US: 1024x1024
- value: 1024x2048
label:
en_US: 1024x2048
- value: 1152x2048
label:
en_US: 1152x2048
- value: 1536x1024
label:
en_US: 1536x1024
- value: 1536x2048
label:
en_US: 1536x2048
- value: 2048x1152
label:
en_US: 2048x1152
default: 1024x1024
label:
en_US: Choose Image Size
zh_Hans: 选择生成图片的大小
form: form
- name: batch_size
type: number
required: true
default: 1
min: 1
max: 4
label:
en_US: Number Images
zh_Hans: 生成图片的数量
form: form
- name: guidance_scale
type: number
required: true
default: 7
min: 0
max: 100
label:
en_US: Guidance Scale
zh_Hans: 与提示词紧密性
human_description:
en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you.
zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。
form: form
- name: num_inference_steps
type: number
required: true
default: 20
min: 1
max: 100
label:
en_US: Num Inference Steps
zh_Hans: 生成图片的步数
human_description:
en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
form: form
- name: seed
type: number
min: 0
max: 9999999999
label:
en_US: Seed
zh_Hans: 种子
human_description:
en_US: The same seed and prompt can produce similar images.
zh_Hans: 相同的种子和提示可以产生相似的图像。
form: form

37
api/poetry.lock generated
View File

@ -6143,6 +6143,19 @@ files = [
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
{file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
{file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
{file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
]
[package.dependencies]
@ -8854,6 +8867,28 @@ files = [
{file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
]
[[package]]
name = "volcengine-python-sdk"
version = "1.0.98"
description = "Volcengine SDK for Python"
optional = false
python-versions = "*"
files = [
{file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
]
[package.dependencies]
anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
certifi = ">=2017.4.17"
httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
python-dateutil = ">=2.1"
six = ">=1.10"
urllib3 = ">=1.23"
[package.extras]
ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
[[package]]
name = "watchfiles"
version = "0.23.0"
@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"

View File

@ -191,6 +191,7 @@ zhipuai = "1.0.7"
# Related transparent dependencies with pinned verion
# required by main implementations
############################################################
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"

View File

@ -34,16 +34,17 @@ const CardView: FC<ICardViewProps> = ({ appId }) => {
const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures)
const updateAppDetail = async () => {
fetchAppDetail({ url: '/apps', id: appId }).then((res) => {
try {
const res = await fetchAppDetail({ url: '/apps', id: appId })
if (systemFeatures.enable_web_sso_switch_component) {
fetchAppSSO({ appId }).then((ssoRes) => {
setAppDetail({ ...res, enable_sso: ssoRes.enabled })
})
const ssoRes = await fetchAppSSO({ appId })
setAppDetail({ ...res, enable_sso: ssoRes.enabled })
}
else {
setAppDetail({ ...res })
}
})
}
catch (error) { console.error(error) }
}
const handleCallbackResult = (err: Error | null, message?: string) => {

View File

@ -43,7 +43,7 @@ export type ConfigParams = {
icon: string
icon_background?: string
show_workflow_steps: boolean
enable_sso: boolean
enable_sso?: boolean
}
const prefixSettings = 'appOverview.overview.appInfo.settings'
@ -157,7 +157,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId,
icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
show_workflow_steps: inputInfo.show_workflow_steps,
enable_sso: inputInfo.enable_sso!,
enable_sso: inputInfo.enable_sso,
}
await onSave?.(params)
setSaveLoading(false)
@ -235,7 +235,11 @@ const SettingsModal: FC<ISettingsModalProps> = ({
<p className='system-xs-medium text-gray-500'>{t(`${prefixSettings}.sso.label`)}</p>
<div className='flex justify-between items-center'>
<div className='font-medium system-sm-semibold flex-grow text-gray-900'>{t(`${prefixSettings}.sso.title`)}</div>
<Tooltip asChild={false} disabled={systemFeatures.sso_enforced_for_web} popupContent={<div className='w-[180px]'>{t(`${prefixSettings}.sso.tooltip`)}</div>}>
<Tooltip
disabled={systemFeatures.sso_enforced_for_web}
popupContent={<div className='w-[180px]'>{t(`${prefixSettings}.sso.tooltip`)}</div>}
asChild={false}
>
<Switch disabled={!systemFeatures.sso_enforced_for_web} defaultValue={systemFeatures.sso_enforced_for_web && inputInfo.enable_sso} onChange={v => setInputInfo({ ...inputInfo, enable_sso: v })}></Switch>
</Tooltip>
</div>

View File

@ -89,7 +89,7 @@ const Tooltip: FC<TooltipProps> = ({
onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)}
asChild={asChild}
>
{children || <div className={triggerClassName || 'p-[1px] w-3.5 h-3.5'}><RiQuestionLine className='text-text-quaternary hover:text-text-tertiary w-full h-full' /></div>}
{children || <div className={triggerClassName || 'p-[1px] w-3.5 h-3.5 shrink-0'}><RiQuestionLine className='text-text-quaternary hover:text-text-tertiary w-full h-full' /></div>}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent
className="z-[9999]"

View File

@ -1,6 +1,5 @@
import { Fragment, useState } from 'react'
import type { FC } from 'react'
import { ValidatingTip } from '../../key-validator/ValidateStatus'
import type {
CredentialFormSchema,

View File

@ -242,13 +242,14 @@ const ParameterItem: FC<ParameterItemProps> = ({
<div className='w-[200px] whitespace-pre-wrap'>{parameterRule.help[language] || parameterRule.help.en_US}</div>
)}
popupClassName='mr-1'
triggerClassName='mr-1 w-4 h-4'
triggerClassName='mr-1 w-4 h-4 shrink-0'
/>
)
}
{
!parameterRule.required && parameterRule.name !== 'stop' && (
<Switch
className='mr-1'
defaultValue={!isNullOrUndefined(value)}
onChange={handleSwitch}
size='md'

View File

@ -3,9 +3,8 @@ import type { FC } from 'react'
import React, { useCallback } from 'react'
import type { VariantProps } from 'class-variance-authority'
import { cva } from 'class-variance-authority'
import { RiQuestionLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import TooltipPlus from '@/app/components/base/tooltip'
import Tooltip from '@/app/components/base/tooltip'
const variants = cva([], {
variants: {
@ -59,15 +58,16 @@ const OptionCard: FC<Props> = ({
onClick={handleSelect}
>
<span>{title}</span>
{tooltip && <TooltipPlus
popupContent={<div className='w-[240px]'
>
{tooltip}
</div>}
asChild={false}
>
<RiQuestionLine className='ml-0.5 w-[14px] h-[14px] text-text-quaternary' />
</TooltipPlus>}
{tooltip
&& <Tooltip
popupContent={
<div className='w-[240px]'>
{tooltip}
</div>
}
asChild={false}
/>
}
</div>
)
}

View File

@ -54,7 +54,7 @@ const translation = {
chatColorThemeInverted: 'Inverted',
invalidHexMessage: 'Invalid hex value',
sso: {
label: 'SSO ENFORCEMENT',
label: 'SSO Authentication',
title: 'WebApp SSO',
description: 'All users are required to login with SSO before using WebApp',
tooltip: 'Contact the administrator to enable WebApp SSO',

View File

@ -16,7 +16,7 @@ export const fetchAppDetail = ({ url, id }: { url: string; id: string }) => {
export const fetchAppSSO = async ({ appId }: { appId: string }) => {
return get<AppSSOResponse>(`/enterprise/app-setting/sso?appID=${appId}`)
}
export const updateAppSSO = async ({ id, enabled }: { id: string;enabled: boolean }) => {
export const updateAppSSO = async ({ id, enabled }: { id: string; enabled: boolean }) => {
return post('/enterprise/app-setting/sso', { body: { app_id: id, enabled } })
}