From bf9349c4dc22d4cbfe76ec1db057cf5a53dd3aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 5 Nov 2024 14:42:47 +0800 Subject: [PATCH] feat: add xAI model provider (#10272) --- .../model_providers/x/__init__.py | 0 .../model_providers/x/_assets/x-ai-logo.svg | 1 + .../model_providers/x/llm/__init__.py | 0 .../model_providers/x/llm/grok-beta.yaml | 63 ++++++ .../model_providers/x/llm/llm.py | 37 ++++ api/core/model_runtime/model_providers/x/x.py | 25 +++ .../model_runtime/model_providers/x/x.yaml | 38 ++++ api/tests/integration_tests/.env.example | 4 + .../model_runtime/x/__init__.py | 0 .../model_runtime/x/test_llm.py | 204 ++++++++++++++++++ 10 files changed, 372 insertions(+) create mode 100644 api/core/model_runtime/model_providers/x/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg create mode 100644 api/core/model_runtime/model_providers/x/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/llm/grok-beta.yaml create mode 100644 api/core/model_runtime/model_providers/x/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/x/x.py create mode 100644 api/core/model_runtime/model_providers/x/x.yaml create mode 100644 api/tests/integration_tests/model_runtime/x/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/x/test_llm.py diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 0000000000..f8b745cb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 0000000000..7c305735b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 0000000000..3f5325a857 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,37 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 0000000000..e3f2b8eeba --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 0000000000..90d1cbfe7e --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. +icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai/v1 + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 99728a8271..6fd144c5c2 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -95,3 +95,7 @@ GPUSTACK_API_KEY= # Gitee AI Credentials GITEE_AI_API_KEY= + +# xAI Credentials +XAI_API_KEY= +XAI_API_BASE= diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/api/tests/integration_tests/model_runtime/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py new file mode 100644 index 0000000000..647a2f6480 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/x/test_llm.py @@ -0,0 +1,204 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = XAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials( + model="gpt-3.5-turbo", + credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, + ) + + model.validate_credentials( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = XAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77