diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index bea045f160..69689fe167 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -459,10 +459,33 @@ class GenerateTaskPipeline: "files": files }) else: - prompts.append({ + prompt_message = prompt_messages[0] + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + params = { "role": 'user', - "text": prompt_messages[0].content - }) + "text": text, + } + + if files: + params['files'] = files + + prompts.append(params) return prompts diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index bf1d7f2d42..360dabe6a4 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -6,6 +6,7 @@ - huggingface_hub - cohere - togetherai +- ollama - zhipuai - baichuan - spark diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml index 6cea787901..e4b625d171 100644 --- a/api/core/model_runtime/model_providers/localai/localai.yaml +++ b/api/core/model_runtime/model_providers/localai/localai.yaml @@ -54,5 +54,5 @@ model_credential_schema: type: text-input required: true placeholder: - zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx - en_US: Enter the url of your LocalAI, for example https://example.com/xxx + zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 + en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080 diff --git a/api/core/model_runtime/model_providers/ollama/__init__.py b/api/core/model_runtime/model_providers/ollama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg new file mode 100644 index 0000000000..39d8a1ece6 --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg new file mode 100644 index 0000000000..f8482a96b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/ollama/llm/__init__.py b/api/core/model_runtime/model_providers/ollama/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py new file mode 100644 index 0000000000..083e60f35e --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -0,0 +1,615 @@ +import json +import logging +import re +from decimal import Decimal +from typing import Optional, Generator, Union, List, cast +from urllib.parse import urljoin + +import requests + +from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \ + UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \ + TextPromptMessageContent, SystemPromptMessage +from core.model_runtime.entities.model_entities import I18nObject, ModelType, \ + PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \ + ModelFeature +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \ + LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \ + InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class OllamaLargeLanguageModel(LargeLanguageModel): + """ + Model class for Ollama large language model. + """ + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=user + ) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # get model mode + model_mode = self.get_model_mode(model, credentials) + + if model_mode == LLMMode.CHAT: + # chat model + return self._num_tokens_from_messages(prompt_messages) + else: + first_prompt_message = prompt_messages[0] + if isinstance(first_prompt_message.content, str): + text = first_prompt_message.content + else: + text = '' + for message_content in first_prompt_message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + text = message_content.data + break + return self._get_num_tokens_by_gpt2(text) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._generate( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + 'num_predict': 5 + }, + stream=False + ) + except InvokeError as ex: + raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + except Exception as ex: + raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + + def _generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke llm completion model + + :param model: model name + :param credentials: credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + headers = { + 'Content-Type': 'application/json' + } + + endpoint_url = credentials['base_url'] + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + # prepare the payload for a simple ping to the model + data = { + 'model': model, + 'stream': stream + } + + if 'format' in model_parameters: + data['format'] = model_parameters['format'] + del model_parameters['format'] + + data['options'] = model_parameters or {} + + if stop: + data['stop'] = "\n".join(stop) + + completion_type = LLMMode.value_of(credentials['mode']) + + if completion_type is LLMMode.CHAT: + endpoint_url = urljoin(endpoint_url, 'api/chat') + data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + else: + endpoint_url = urljoin(endpoint_url, 'api/generate') + first_prompt_message = prompt_messages[0] + if isinstance(first_prompt_message, UserPromptMessage): + first_prompt_message = cast(UserPromptMessage, first_prompt_message) + if isinstance(first_prompt_message.content, str): + data['prompt'] = first_prompt_message.content + else: + text = '' + images = [] + for message_content in first_prompt_message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + text = message_content.data + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + images.append(image_data) + + data['prompt'] = text + data['images'] = images + + # send a post request to validate the credentials + response = requests.post( + endpoint_url, + headers=headers, + json=data, + timeout=(10, 60), + stream=stream + ) + + response.encoding = "utf-8" + if response.status_code != 200: + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") + + if stream: + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) + + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) + + def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode, + response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm completion response + + :param model: model name + :param credentials: model credentials + :param completion_type: completion type + :param response: response + :param prompt_messages: prompt messages + :return: llm result + """ + response_json = response.json() + + if completion_type is LLMMode.CHAT: + message = response_json.get('message', {}) + response_content = message.get('content', '') + else: + response_content = response_json['response'] + + assistant_message = AssistantPromptMessage(content=response_content) + + if 'prompt_eval_count' in response_json and 'eval_count' in response_json: + # transform usage + prompt_tokens = response_json["prompt_eval_count"] + completion_tokens = response_json["eval_count"] + else: + # calculate num tokens + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) + completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=response_json["model"], + prompt_messages=prompt_messages, + message=assistant_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode, + response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm completion stream response + + :param model: model name + :param credentials: model credentials + :param completion_type: completion type + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + full_text = '' + chunk_index = 0 + + def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ + -> LLMResultChunk: + # calculate num tokens + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) + completion_tokens = self._get_num_tokens_by_gpt2(full_text) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + return LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=message, + finish_reason=finish_reason, + usage=usage + ) + ) + + for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'): + if not chunk: + continue + + try: + chunk_json = json.loads(chunk) + # stream ended + except json.JSONDecodeError as e: + yield create_final_llm_result_chunk( + index=chunk_index, + message=AssistantPromptMessage(content=""), + finish_reason="Non-JSON encountered." + ) + + chunk_index += 1 + break + + if completion_type is LLMMode.CHAT: + if not chunk_json: + continue + + if 'message' not in chunk_json: + text = '' + else: + text = chunk_json.get('message').get('content', '') + else: + if not chunk_json: + continue + + # transform assistant message to prompt message + text = chunk_json['response'] + + assistant_prompt_message = AssistantPromptMessage( + content=text + ) + + full_text += text + + if chunk_json['done']: + # calculate num tokens + if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json: + # transform usage + prompt_tokens = chunk_json["prompt_eval_count"] + completion_tokens = chunk_json["eval_count"] + else: + # calculate num tokens + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) + completion_tokens = self._get_num_tokens_by_gpt2(full_text) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=chunk_json['model'], + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=assistant_prompt_message, + finish_reason='stop', + usage=usage + ) + ) + else: + yield LLMResultChunk( + model=chunk_json['model'], + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=assistant_prompt_message, + ) + ) + + chunk_index += 1 + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for Ollama API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + text = '' + images = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + text = message_content.data + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + images.append(image_data) + + message_dict = {"role": "user", "content": text, "images": images} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + return message_dict + + def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int: + """ + Calculate num tokens. + + :param messages: messages + """ + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + return num_tokens + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + Get customizable model schema. + + :param model: model name + :param credentials: credentials + + :return: model schema + """ + extras = {} + + if 'vision_support' in credentials and credentials['vision_support'] == 'true': + extras['features'] = [ModelFeature.VISION] + + entity = AIModelEntity( + model=model, + label=I18nObject( + zh_Hans=model, + en_US=model + ), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + use_template=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="The temperature of the model. " + "Increasing the temperature will make the model answer " + "more creatively. (Default: 0.8)"), + default=0.8, + min=0, + max=2 + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + use_template=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " + "more diverse text, while a lower value (e.g., 0.5) will generate more " + "focused and conservative text. (Default: 0.9)"), + default=0.9, + min=0, + max=1 + ), + ParameterRule( + name="top_k", + label=I18nObject(en_US="Top K"), + type=ParameterType.INT, + help=I18nObject(en_US="Reduces the probability of generating nonsense. " + "A higher value (e.g. 100) will give more diverse answers, " + "while a lower value (e.g. 10) will be more conservative. (Default: 40)"), + default=40, + min=1, + max=100 + ), + ParameterRule( + name='repeat_penalty', + label=I18nObject(en_US="Repeat Penalty"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="Sets how strongly to penalize repetitions. " + "A higher value (e.g., 1.5) will penalize repetitions more strongly, " + "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"), + default=1.1, + min=-2, + max=2 + ), + ParameterRule( + name='num_predict', + use_template='max_tokens', + label=I18nObject(en_US="Num Predict"), + type=ParameterType.INT, + help=I18nObject(en_US="Maximum number of tokens to predict when generating text. " + "(Default: 128, -1 = infinite generation, -2 = fill context)"), + default=128, + min=-2, + max=int(credentials.get('max_tokens', 4096)), + ), + ParameterRule( + name='mirostat', + label=I18nObject(en_US="Mirostat sampling"), + type=ParameterType.INT, + help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. " + "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"), + default=0, + min=0, + max=2 + ), + ParameterRule( + name='mirostat_eta', + label=I18nObject(en_US="Mirostat Eta"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from " + "the generated text. A lower learning rate will result in slower adjustments, " + "while a higher learning rate will make the algorithm more responsive. " + "(Default: 0.1)"), + default=0.1, + precision=1 + ), + ParameterRule( + name='mirostat_tau', + label=I18nObject(en_US="Mirostat Tau"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. " + "A lower value will result in more focused and coherent text. (Default: 5.0)"), + default=5.0, + precision=1 + ), + ParameterRule( + name='num_ctx', + label=I18nObject(en_US="Size of context window"), + type=ParameterType.INT, + help=I18nObject(en_US="Sets the size of the context window used to generate the next token. " + "(Default: 2048)"), + default=2048, + min=1 + ), + ParameterRule( + name='num_gpu', + label=I18nObject(en_US="Num GPU"), + type=ParameterType.INT, + help=I18nObject(en_US="The number of layers to send to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable."), + default=1, + min=0, + max=1 + ), + ParameterRule( + name='num_thread', + label=I18nObject(en_US="Num Thread"), + type=ParameterType.INT, + help=I18nObject(en_US="Sets the number of threads to use during computation. " + "By default, Ollama will detect this for optimal performance. " + "It is recommended to set this value to the number of physical CPU cores " + "your system has (as opposed to the logical number of cores)."), + min=1, + ), + ParameterRule( + name='repeat_last_n', + label=I18nObject(en_US="Repeat last N"), + type=ParameterType.INT, + help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. " + "(Default: 64, 0 = disabled, -1 = num_ctx)"), + default=64, + min=-1 + ), + ParameterRule( + name='tfs_z', + label=I18nObject(en_US="TFS Z"), + type=ParameterType.FLOAT, + help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens " + "from the output. A higher value (e.g., 2.0) will reduce the impact more, " + "while a value of 1.0 disables this setting. (default: 1)"), + default=1, + precision=1 + ), + ParameterRule( + name='seed', + label=I18nObject(en_US="Seed"), + type=ParameterType.INT, + help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to " + "a specific number will make the model generate the same text for " + "the same prompt. (Default: 0)"), + default=0 + ), + ParameterRule( + name='format', + label=I18nObject(en_US="Format"), + type=ParameterType.STRING, + help=I18nObject(en_US="the format to return a response in." + " Currently the only accepted value is json."), + options=['json'], + ) + ], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + output=Decimal(credentials.get('output_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ), + **extras + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeAuthorizationError: [ + requests.exceptions.InvalidHeader, # Missing or Invalid API Key + ], + InvokeBadRequestError: [ + requests.exceptions.HTTPError, # Invalid Endpoint URL or model name + requests.exceptions.InvalidURL, # Misconfigured request or other API error + ], + InvokeRateLimitError: [ + requests.exceptions.RetryError # Too many requests sent in a short period of time + ], + InvokeServerUnavailableError: [ + requests.exceptions.ConnectionError, # Engine Overloaded + requests.exceptions.HTTPError # Server Error + ], + InvokeConnectionError: [ + requests.exceptions.ConnectTimeout, # Timeout + requests.exceptions.ReadTimeout # Timeout + ] + } diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py new file mode 100644 index 0000000000..f8a17b98a0 --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class OpenAIProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + pass diff --git a/api/core/model_runtime/model_providers/ollama/ollama.yaml b/api/core/model_runtime/model_providers/ollama/ollama.yaml new file mode 100644 index 0000000000..d796831461 --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/ollama.yaml @@ -0,0 +1,98 @@ +provider: ollama +label: + en_US: Ollama +icon_large: + en_US: icon_l_en.svg +icon_small: + en_US: icon_s_en.svg +background: "#F9FAFB" +help: + title: + en_US: How to integrate with Ollama + zh_Hans: 如何集成 Ollama + url: + en_US: https://docs.dify.ai/advanced/model-configuration/ollama +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: base_url + label: + zh_Hans: 基础 URL + en_US: Base URL + type: text-input + required: true + placeholder: + zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434 + en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434 + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: 模型类型 + en_US: Completion mode + type: select + required: true + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: '4096' + type: text-input + required: true + - variable: vision_support + label: + zh_Hans: 是否支持 Vision + en_US: Vision support + show_on: + - variable: __model_type + value: llm + default: 'false' + type: radio + required: false + options: + - value: 'true' + label: + en_US: Yes + zh_Hans: 是 + - value: 'false' + label: + en_US: No + zh_Hans: 否 diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py b/api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py new file mode 100644 index 0000000000..f496898180 --- /dev/null +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -0,0 +1,221 @@ +import logging +import time +from decimal import Decimal +from typing import Optional +from urllib.parse import urljoin +import requests +import json + +import numpy as np + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \ + PriceConfig +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \ + InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +logger = logging.getLogger(__name__) + + +class OllamaEmbeddingModel(TextEmbeddingModel): + """ + Model class for an Ollama text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + # Prepare headers and payload for the request + headers = { + 'Content-Type': 'application/json' + } + + endpoint_url = credentials.get('base_url') + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + endpoint_url = urljoin(endpoint_url, 'api/embeddings') + + # get model properties + context_size = self._get_context_size(model, credentials) + + inputs = [] + used_tokens = 0 + + for i, text in enumerate(texts): + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0: cutoff]) + else: + inputs.append(text) + + batched_embeddings = [] + + for text in inputs: + # Prepare the payload for the request + payload = { + 'prompt': text, + 'model': model, + } + + # Make the request to the OpenAI API + response = requests.post( + endpoint_url, + headers=headers, + data=json.dumps(payload), + timeout=(10, 300) + ) + + response.raise_for_status() # Raise an exception for HTTP errors + response_data = response.json() + + # Extract embeddings and used tokens from the response + embeddings = response_data['embedding'] + embedding_used_tokens = self.get_num_tokens(model, credentials, [text]) + + used_tokens += embedding_used_tokens + batched_embeddings.append(embeddings) + + # calc usage + usage = self._calc_response_usage( + model=model, + credentials=credentials, + tokens=used_tokens + ) + + return TextEmbeddingResult( + embeddings=batched_embeddings, + usage=usage, + model=model + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Approximate number of tokens for given messages using GPT2 tokenizer + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + texts=['ping'] + ) + except InvokeError as ex: + raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + except Exception as ex: + raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeAuthorizationError: [ + requests.exceptions.InvalidHeader, # Missing or Invalid API Key + ], + InvokeBadRequestError: [ + requests.exceptions.HTTPError, # Invalid Endpoint URL or model name + requests.exceptions.InvalidURL, # Misconfigured request or other API error + ], + InvokeRateLimitError: [ + requests.exceptions.RetryError # Too many requests sent in a short period of time + ], + InvokeServerUnavailableError: [ + requests.exceptions.ConnectionError, # Engine Overloaded + requests.exceptions.HTTPError # Server Error + ], + InvokeConnectionError: [ + requests.exceptions.ConnectTimeout, # Timeout + requests.exceptions.ReadTimeout # Timeout + ] + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index b92b6dce3f..acb974b050 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message=AssistantPromptMessage(content=""), finish_reason="Non-JSON encountered." ) + break if not chunk_json or len(chunk_json['choices']) == 0: continue diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 26925606b2..088738c0ff 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -33,8 +33,8 @@ model_credential_schema: type: text-input required: true placeholder: - zh_Hans: Base URL, eg. https://api.openai.com/v1 - en_US: Base URL, eg. https://api.openai.com/v1 + zh_Hans: Base URL, e.g. https://api.openai.com/v1 + en_US: Base URL, e.g. https://api.openai.com/v1 - variable: mode show_on: - variable: __model_type diff --git a/api/core/model_runtime/model_providers/openllm/openllm.yaml b/api/core/model_runtime/model_providers/openllm/openllm.yaml index bd93baa727..fef52695e3 100644 --- a/api/core/model_runtime/model_providers/openllm/openllm.yaml +++ b/api/core/model_runtime/model_providers/openllm/openllm.yaml @@ -33,5 +33,5 @@ model_credential_schema: type: text-input required: true placeholder: - zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx - en_US: Enter the url of your OpenLLM, for example https://example.com/xxx + zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000 + en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000 diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml index f5391d0324..bb6c6d8668 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference.yaml +++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml @@ -34,8 +34,8 @@ model_credential_schema: type: secret-input required: true placeholder: - zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx - en_US: Enter the url of your Xinference, for example https://example.com/xxx + zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997 + en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997 - variable: model_uid label: zh_Hans: 模型UID diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 28fea7c3ce..19c8e2d5ad 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -121,6 +121,7 @@ class PromptTransform: prompt_template_entity=prompt_template_entity, inputs=inputs, query=query, + files=files, context=context, memory=memory, model_config=model_config @@ -343,7 +344,14 @@ class PromptTransform: prompt_message = UserPromptMessage(content=prompt_message_contents) else: - prompt_message = UserPromptMessage(content=prompt) + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_message = UserPromptMessage(content=prompt_message_contents) + else: + prompt_message = UserPromptMessage(content=prompt) return [prompt_message] @@ -434,6 +442,7 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, + files: List[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) -> List[PromptMessage]: @@ -461,7 +470,14 @@ class PromptTransform: prompt = self._format_prompt(prompt_template, prompt_inputs) - prompt_messages.append(UserPromptMessage(content=prompt)) + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) return prompt_messages diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 89080b0788..04abacf73d 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -62,5 +62,8 @@ COHERE_API_KEY= # Jina Credentials JINA_API_KEY= +# Ollama Credentials +OLLAMA_BASE_URL= + # Mock Switch MOCK_SWITCH=false \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/ollama/__init__.py b/api/tests/integration_tests/model_runtime/ollama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py new file mode 100644 index 0000000000..5543085a54 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -0,0 +1,260 @@ +import os +from typing import Generator + +import pytest + +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \ + SystemPromptMessage, TextPromptMessageContent, ImagePromptMessageContent +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ + LLMResultChunk +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel + + +def test_validate_credentials(): + model = OllamaLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='mistral:text', + credentials={ + 'base_url': 'http://localhost:21434', + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + } + ) + + model.validate_credentials( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + } + ) + + +def test_invoke_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + 'num_predict': 10 + }, + stop=['How'], + stream=False + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + 'num_predict': 10 + }, + stop=['How'], + stream=True + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_invoke_completion_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'completion', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + 'num_predict': 10 + }, + stop=['How'], + stream=False + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_completion_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'completion', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + 'num_predict': 10 + }, + stop=['How'], + stream=True + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_invoke_completion_model_with_vision(): + model = OllamaLargeLanguageModel() + + result = model.invoke( + model='llava', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'completion', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + UserPromptMessage( + content=[ + TextPromptMessageContent( + data='What is this in this picture?', + ), + ImagePromptMessageContent( + data='' + ) + ] + ) + ], + model_parameters={ + 'temperature': 0.1, + 'num_predict': 100 + }, + stream=False, + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_chat_model_with_vision(): + model = OllamaLargeLanguageModel() + + result = model.invoke( + model='llava', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + UserPromptMessage( + content=[ + TextPromptMessageContent( + data='What is this in this picture?', + ), + ImagePromptMessageContent( + data='' + ) + ] + ) + ], + model_parameters={ + 'temperature': 0.1, + 'num_predict': 100 + }, + stream=False, + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_get_num_tokens(): + model = OllamaLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 2048, + 'max_tokens': 2048, + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' + ) + ] + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py new file mode 100644 index 0000000000..c5f5918235 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -0,0 +1,71 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel + + +def test_validate_credentials(): + model = OllamaEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='mistral:text', + credentials={ + 'base_url': 'http://localhost:21434', + 'mode': 'chat', + 'context_size': 4096, + } + ) + + model.validate_credentials( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 4096, + } + ) + + +def test_invoke_model(): + model = OllamaEmbeddingModel() + + result = model.invoke( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 4096, + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = OllamaEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='mistral:text', + credentials={ + 'base_url': os.environ.get('OLLAMA_BASE_URL'), + 'mode': 'chat', + 'context_size': 4096, + }, + texts=[ + "hello", + "world" + ] + ) + + assert num_tokens == 2