From 0ebd9856725e273621db5c7943a9e18e01991aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E7=A8=8B?= Date: Mon, 28 Oct 2024 16:52:12 +0800 Subject: [PATCH] feat: add models for gitee.ai (#9490) --- .../gitee_ai/_assets/Gitee-AI-Logo-full.svg | 6 + .../gitee_ai/_assets/Gitee-AI-Logo.svg | 3 + .../model_providers/gitee_ai/_common.py | 47 +++++++ .../model_providers/gitee_ai/gitee_ai.py | 25 ++++ .../model_providers/gitee_ai/gitee_ai.yaml | 35 +++++ .../gitee_ai/llm/Qwen2-72B-Instruct.yaml | 105 ++++++++++++++ .../gitee_ai/llm/Qwen2-7B-Instruct.yaml | 105 ++++++++++++++ .../gitee_ai/llm/Yi-1.5-34B-Chat.yaml | 105 ++++++++++++++ .../gitee_ai/llm/_position.yaml | 7 + .../gitee_ai/llm/codegeex4-all-9b.yaml | 105 ++++++++++++++ .../llm/deepseek-coder-33B-instruct-chat.yaml | 105 ++++++++++++++ ...epseek-coder-33B-instruct-completions.yaml | 91 ++++++++++++ .../gitee_ai/llm/glm-4-9b-chat.yaml | 105 ++++++++++++++ .../model_providers/gitee_ai/llm/llm.py | 47 +++++++ .../gitee_ai/rerank/__init__.py | 0 .../gitee_ai/rerank/_position.yaml | 1 + .../gitee_ai/rerank/bge-reranker-v2-m3.yaml | 4 + .../model_providers/gitee_ai/rerank/rerank.py | 128 +++++++++++++++++ .../gitee_ai/speech2text/__init__.py | 0 .../gitee_ai/speech2text/_position.yaml | 2 + .../gitee_ai/speech2text/speech2text.py | 53 +++++++ .../gitee_ai/speech2text/whisper-base.yaml | 5 + .../gitee_ai/speech2text/whisper-large.yaml | 5 + .../gitee_ai/text_embedding/_position.yaml | 3 + .../text_embedding/bge-large-zh-v1.5.yaml | 8 ++ .../gitee_ai/text_embedding/bge-m3.yaml | 8 ++ .../text_embedding/bge-small-zh-v1.5.yaml | 8 ++ .../gitee_ai/text_embedding/text_embedding.py | 31 ++++ .../model_providers/gitee_ai/tts/ChatTTS.yaml | 11 ++ .../tts/FunAudioLLM-CosyVoice-300M.yaml | 11 ++ .../model_providers/gitee_ai/tts/__init__.py | 0 .../gitee_ai/tts/_position.yaml | 4 + .../gitee_ai/tts/fish-speech-1.2-sft.yaml | 11 ++ .../gitee_ai/tts/speecht5_tts.yaml | 11 ++ .../model_providers/gitee_ai/tts/tts.py | 79 +++++++++++ api/pytest.ini | 1 + api/tests/integration_tests/.env.example | 3 + .../model_runtime/gitee_ai/__init__.py | 0 .../model_runtime/gitee_ai/test_llm.py | 132 ++++++++++++++++++ .../model_runtime/gitee_ai/test_provider.py | 15 ++ .../model_runtime/gitee_ai/test_rerank.py | 47 +++++++ .../gitee_ai/test_speech2text.py | 45 ++++++ .../gitee_ai/test_text_embedding.py | 46 ++++++ .../model_runtime/gitee_ai/test_tts.py | 23 +++ 44 files changed, 1586 insertions(+) create mode 100644 api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg create mode 100644 api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg create mode 100644 api/core/model_runtime/model_providers/gitee_ai/_common.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml create mode 100644 api/core/model_runtime/model_providers/gitee_ai/tts/tts.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py create mode 100644 api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg new file mode 100644 index 0000000000..f9738b585b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg new file mode 100644 index 0000000000..1f51187f19 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py new file mode 100644 index 0000000000..0750f3b75d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -0,0 +1,47 @@ +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonGiteeAI: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + RequestFailure, + ], + InvokeServerUnavailableError: [ + ServiceUnavailableError, + ], + InvokeRateLimitError: [], + InvokeAuthorizationError: [ + AuthenticationError, + ], + InvokeBadRequestError: [ + InvalidParameter, + UnsupportedModel, + UnsupportedHTTPMethod, + ], + } diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py new file mode 100644 index 0000000000..ca67594ce4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class GiteeAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml new file mode 100644 index 0000000000..7f7d0f2e53 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml @@ -0,0 +1,35 @@ +provider: gitee_ai +label: + en_US: Gitee AI + zh_Hans: Gitee AI +description: + en_US: 快速体验大模型,领先探索 AI 开源世界 + zh_Hans: 快速体验大模型,领先探索 AI 开源世界 +icon_small: + en_US: Gitee-AI-Logo.svg +icon_large: + en_US: Gitee-AI-Logo-full.svg +help: + title: + en_US: Get your token from Gitee AI + zh_Hans: 从 Gitee AI 获取 token + url: + en_US: https://ai.gitee.com/dashboard/settings/tokens +supported_model_types: + - llm + - text-embedding + - rerank + - speech2text + - tts +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml new file mode 100644 index 0000000000..0348438a75 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml @@ -0,0 +1,105 @@ +model: Qwen2-72B-Instruct +label: + zh_Hans: Qwen2-72B-Instruct + en_US: Qwen2-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 6400 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml new file mode 100644 index 0000000000..ba1ad788f5 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml @@ -0,0 +1,105 @@ +model: Qwen2-7B-Instruct +label: + zh_Hans: Qwen2-7B-Instruct + en_US: Qwen2-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml new file mode 100644 index 0000000000..f7260c987b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml @@ -0,0 +1,105 @@ +model: Yi-1.5-34B-Chat +label: + zh_Hans: Yi-1.5-34B-Chat + en_US: Yi-1.5-34B-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml new file mode 100644 index 0000000000..21f6120742 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml @@ -0,0 +1,7 @@ +- Qwen2-7B-Instruct +- Qwen2-72B-Instruct +- Yi-1.5-34B-Chat +- glm-4-9b-chat +- deepseek-coder-33B-instruct-chat +- deepseek-coder-33B-instruct-completions +- codegeex4-all-9b diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml new file mode 100644 index 0000000000..8632cd92ab --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml @@ -0,0 +1,105 @@ +model: codegeex4-all-9b +label: + zh_Hans: codegeex4-all-9b + en_US: codegeex4-all-9b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 40960 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml new file mode 100644 index 0000000000..2ac00761d5 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml @@ -0,0 +1,105 @@ +model: deepseek-coder-33B-instruct-chat +label: + zh_Hans: deepseek-coder-33B-instruct-chat + en_US: deepseek-coder-33B-instruct-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml new file mode 100644 index 0000000000..7c364d89f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml @@ -0,0 +1,91 @@ +model: deepseek-coder-33B-instruct-completions +label: + zh_Hans: deepseek-coder-33B-instruct-completions + en_US: deepseek-coder-33B-instruct-completions +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml new file mode 100644 index 0000000000..2afe1cf959 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml @@ -0,0 +1,105 @@ +model: glm-4-9b-chat +label: + zh_Hans: glm-4-9b-chat + en_US: glm-4-9b-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py new file mode 100644 index 0000000000..b65db6f665 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py @@ -0,0 +1,47 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + MODEL_TO_IDENTITY: dict[str, str] = { + "Yi-1.5-34B-Chat": "Yi-34B-Chat", + "deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct", + "deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials, model, model_parameters) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, model, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None: + if model is None: + model = "bge-large-zh-v1.5" + + model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model) + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/" + if model.endswith("completions"): + credentials["mode"] = LLMMode.COMPLETION.value + else: + credentials["mode"] = LLMMode.CHAT.value diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml new file mode 100644 index 0000000000..83162fd338 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml @@ -0,0 +1 @@ +- bge-reranker-v2-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 0000000000..f0681641e1 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 1024 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py new file mode 100644 index 0000000000..231345c2f4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -0,0 +1,128 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GiteeAIRerankModel(RerankModel): + """ + Model class for rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless") + base_url = base_url.removesuffix("/") + + try: + body = {"model": model, "query": query, "documents": docs} + if top_n is not None: + body["top_n"] = top_n + response = httpx.post( + f"{base_url}/{model}/rerank", + json=body, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, + ) + + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["results"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.01, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml new file mode 100644 index 0000000000..8e9b47598b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml @@ -0,0 +1,2 @@ +- whisper-base +- whisper-large diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py new file mode 100644 index 0000000000..5597f5b43e --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py @@ -0,0 +1,53 @@ +import os +from typing import IO, Optional + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI + + +class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel): + """ + Model class for OpenAI Compatible Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text + + endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text" + files = [("file", file)] + _, file_ext = os.path.splitext(file.name) + headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"} + response = requests.post(endpoint_url, headers=headers, files=files) + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + response_data = response.json() + return response_data["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, "rb") as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml new file mode 100644 index 0000000000..a50bf5fc2d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml @@ -0,0 +1,5 @@ +model: whisper-base +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml new file mode 100644 index 0000000000..1be7b1a391 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml @@ -0,0 +1,5 @@ +model: whisper-large +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml new file mode 100644 index 0000000000..e8abe6440d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml @@ -0,0 +1,3 @@ +- bge-large-zh-v1.5 +- bge-small-zh-v1.5 +- bge-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml new file mode 100644 index 0000000000..9e3ca76e88 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-large-zh-v1.5 +label: + zh_Hans: bge-large-zh-v1.5 + en_US: bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml new file mode 100644 index 0000000000..a7a99a98a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml @@ -0,0 +1,8 @@ +model: bge-m3 +label: + zh_Hans: bge-m3 + en_US: bge-m3 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml new file mode 100644 index 0000000000..bd760408fa --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-small-zh-v1.5 +label: + zh_Hans: bge-small-zh-v1.5 + en_US: bge-small-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py new file mode 100644 index 0000000000..b833c5652c --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -0,0 +1,31 @@ +from typing import Optional + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GiteeAIEmbeddingModel(OAICompatEmbeddingModel): + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + self._add_custom_parameters(credentials, model) + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str) -> None: + if model is None: + model = "bge-m3" + + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml new file mode 100644 index 0000000000..940391dfab --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml @@ -0,0 +1,11 @@ +model: ChatTTS +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml new file mode 100644 index 0000000000..8fc5734801 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml @@ -0,0 +1,11 @@ +model: FunAudioLLM-CosyVoice-300M +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml new file mode 100644 index 0000000000..13c6ec8454 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml @@ -0,0 +1,4 @@ +- speecht5_tts +- ChatTTS +- fish-speech-1.2-sft +- FunAudioLLM-CosyVoice-300M diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml new file mode 100644 index 0000000000..93cc28bc9d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml @@ -0,0 +1,11 @@ +model: fish-speech-1.2-sft +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml new file mode 100644 index 0000000000..f9c843bd41 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml @@ -0,0 +1,11 @@ +model: speecht5_tts +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py new file mode 100644 index 0000000000..ed2bd5b13d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -0,0 +1,79 @@ +from typing import Optional + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI + + +class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): + """ + Model class for OpenAI Speech to text model. + """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :return: text translated to audio file + """ + try: + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text="Hello Dify!", + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + """ + _tts_invoke_streaming text2speech model + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech + endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech" + + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = {"inputs": content_text} + response = requests.post(endpoint_url, headers=headers, json=payload) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + + data = response.content + + for i in range(0, len(data), 1024): + yield data[i : i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex)) diff --git a/api/pytest.ini b/api/pytest.ini index dcca08e2e5..a23a4b3f3d 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -27,3 +27,4 @@ env = XINFERENCE_GENERATION_MODEL_UID = generate XINFERENCE_RERANK_MODEL_UID = rerank XINFERENCE_SERVER_URL = http://a.abc.com:11451 + GITEE_AI_API_KEY = aaaaaaaaaaaaaaaaaaaa diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2d52399d29..6791cd891b 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -83,3 +83,6 @@ VOLC_EMBEDDING_ENDPOINT_ID= # 360 AI Credentials ZHINAO_API_KEY= + +# Gitee AI Credentials +GITEE_AI_API_KEY= diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py b/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py new file mode 100644 index 0000000000..753c52ce31 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py @@ -0,0 +1,132 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel + + +def test_predefined_models(): + model = GiteeAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + model = GiteeAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + "stream": False, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_stream_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + + +def test_get_num_tokens(): + model = GiteeAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py new file mode 100644 index 0000000000..f12ed54a45 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider + + +def test_validate_provider_credentials(): + provider = GiteeAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "invalid_key"}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py new file mode 100644 index 0000000000..0e5914a61f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel + + +def test_validate_credentials(): + model = GiteeAIRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GiteeAIRerankModel() + result = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + top_n=1, + score_threshold=0.01, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].score >= 0.01 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py new file mode 100644 index 0000000000..4a01453fdd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel + + +def test_validate_credentials(): + model = GiteeAISpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-base", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="whisper-base", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_model(): + model = GiteeAISpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file + ) + + assert isinstance(result, str) + assert result == "1 2 3 4 5 6 7 8 9 10" diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py new file mode 100644 index 0000000000..34648f0bc8 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py @@ -0,0 +1,46 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel + + +def test_validate_credentials(): + model = GiteeAIEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) + + +def test_invoke_model(): + model = GiteeAIEmbeddingModel() + + result = model.invoke( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + user="user", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + + +def test_get_num_tokens(): + model = GiteeAIEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py new file mode 100644 index 0000000000..9f18161a7b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py @@ -0,0 +1,23 @@ +import os + +from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel + + +def test_invoke_model(): + model = GiteeAIText2SpeechModel() + + result = model.invoke( + model="speecht5_tts", + tenant_id="test", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + content_text="Hello, world!", + voice="", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b""