diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py
index bea045f160..69689fe167 100644
--- a/api/core/app_runner/generate_task_pipeline.py
+++ b/api/core/app_runner/generate_task_pipeline.py
@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
"files": files
})
else:
- prompts.append({
+ prompt_message = prompt_messages[0]
+ text = ''
+ files = []
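+                # the content may be a plain string or a list of multimodal items (text and images)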
+ if isinstance(prompt_message.content, list):
+ for content in prompt_message.content:
+ if content.type == PromptMessageContentType.TEXT:
+ content = cast(TextPromptMessageContent, content)
+ text += content.data
+ else:
+ content = cast(ImagePromptMessageContent, content)
+ files.append({
+ "type": 'image',
+ "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+ "detail": content.detail.value
+ })
+ else:
+ text = prompt_message.content
+
+ params = {
"role": 'user',
- "text": prompt_messages[0].content
- })
+ "text": text,
+ }
+
+ if files:
+ params['files'] = files
+
+ prompts.append(params)
return prompts
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index bf1d7f2d42..360dabe6a4 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -6,6 +6,7 @@
- huggingface_hub
- cohere
- togetherai
+- ollama
- zhipuai
- baichuan
- spark
diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml
index 6cea787901..e4b625d171 100644
--- a/api/core/model_runtime/model_providers/localai/localai.yaml
+++ b/api/core/model_runtime/model_providers/localai/localai.yaml
@@ -54,5 +54,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
- zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
- en_US: Enter the url of your LocalAI, for example https://example.com/xxx
+ zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
+ en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
diff --git a/api/core/model_runtime/model_providers/ollama/__init__.py b/api/core/model_runtime/model_providers/ollama/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..39d8a1ece6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg
@@ -0,0 +1,15 @@
+
diff --git a/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..f8482a96b9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg
@@ -0,0 +1,15 @@
+
diff --git a/api/core/model_runtime/model_providers/ollama/llm/__init__.py b/api/core/model_runtime/model_providers/ollama/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py
new file mode 100644
index 0000000000..083e60f35e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -0,0 +1,615 @@
+import json
+import logging
+import re
+from decimal import Decimal
+from typing import Optional, Generator, Union, List, cast
+from urllib.parse import urljoin
+
+import requests
+
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
+ UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
+ TextPromptMessageContent, SystemPromptMessage
+from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
+ PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
+ ModelFeature
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
+ LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
+ InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaLargeLanguageModel(LargeLanguageModel):
+ """
+ Model class for Ollama large language model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ return self._generate(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return:
+ """
+ # get model mode
+ model_mode = self.get_model_mode(model, credentials)
+
+ if model_mode == LLMMode.CHAT:
+ # chat model
+ return self._num_tokens_from_messages(prompt_messages)
+ else:
+ first_prompt_message = prompt_messages[0]
+ if isinstance(first_prompt_message.content, str):
+ text = first_prompt_message.content
+ else:
+ text = ''
+ for message_content in first_prompt_message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(TextPromptMessageContent, message_content)
+ text = message_content.data
+ break
+ return self._get_num_tokens_by_gpt2(text)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._generate(
+ model=model,
+ credentials=credentials,
+ prompt_messages=[UserPromptMessage(content="ping")],
+ model_parameters={
+ 'num_predict': 5
+ },
+ stream=False
+ )
+ except InvokeError as ex:
+ raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+ except Exception as ex:
+ raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+
+ def _generate(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+ stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+ """
+ Invoke llm completion model
+
+ :param model: model name
+ :param credentials: credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+
+ endpoint_url = credentials['base_url']
+ if not endpoint_url.endswith('/'):
+ endpoint_url += '/'
+
+        # prepare the base request payload
+ data = {
+ 'model': model,
+ 'stream': stream
+ }
+
+ if 'format' in model_parameters:
+ data['format'] = model_parameters['format']
+ del model_parameters['format']
+
+ data['options'] = model_parameters or {}
+
+ if stop:
+ data['stop'] = "\n".join(stop)
+
+ completion_type = LLMMode.value_of(credentials['mode'])
+
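+        # chat mode uses the /api/chat endpoint, completion mode uses /api/generate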
+ if completion_type is LLMMode.CHAT:
+ endpoint_url = urljoin(endpoint_url, 'api/chat')
+ data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+ else:
+ endpoint_url = urljoin(endpoint_url, 'api/generate')
+ first_prompt_message = prompt_messages[0]
+ if isinstance(first_prompt_message, UserPromptMessage):
+ first_prompt_message = cast(UserPromptMessage, first_prompt_message)
+ if isinstance(first_prompt_message.content, str):
+ data['prompt'] = first_prompt_message.content
+ else:
+ text = ''
+ images = []
+ for message_content in first_prompt_message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(TextPromptMessageContent, message_content)
+ text = message_content.data
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(ImagePromptMessageContent, message_content)
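+                            # strip any data URI prefix; Ollama expects plain base64-encoded image data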
+ image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+ images.append(image_data)
+
+ data['prompt'] = text
+ data['images'] = images
+
+        # send a POST request to the Ollama API
+ response = requests.post(
+ endpoint_url,
+ headers=headers,
+ json=data,
+ timeout=(10, 60),
+ stream=stream
+ )
+
+ response.encoding = "utf-8"
+ if response.status_code != 200:
+ raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+
+ if stream:
+ return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
+
+ return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+
+ def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
+ response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
+ """
+ Handle llm completion response
+
+ :param model: model name
+ :param credentials: model credentials
+ :param completion_type: completion type
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm result
+ """
+ response_json = response.json()
+
+ if completion_type is LLMMode.CHAT:
+ message = response_json.get('message', {})
+ response_content = message.get('content', '')
+ else:
+ response_content = response_json['response']
+
+ assistant_message = AssistantPromptMessage(content=response_content)
+
+ if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
+ # transform usage
+ prompt_tokens = response_json["prompt_eval_count"]
+ completion_tokens = response_json["eval_count"]
+ else:
+ # calculate num tokens
+ prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+ completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # transform response
+ result = LLMResult(
+ model=response_json["model"],
+ prompt_messages=prompt_messages,
+ message=assistant_message,
+ usage=usage,
+ )
+
+ return result
+
+ def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
+ response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
+ """
+ Handle llm completion stream response
+
+ :param model: model name
+ :param credentials: model credentials
+ :param completion_type: completion type
+ :param response: response
+ :param prompt_messages: prompt messages
+ :return: llm response chunk generator result
+ """
+ full_text = ''
+ chunk_index = 0
+
+ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+ -> LLMResultChunk:
+ # calculate num tokens
+ prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+ completion_tokens = self._get_num_tokens_by_gpt2(full_text)
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ return LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=index,
+ message=message,
+ finish_reason=finish_reason,
+ usage=usage
+ )
+ )
+
+ for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
+ if not chunk:
+ continue
+
+            try:
+                chunk_json = json.loads(chunk)
+            except json.JSONDecodeError:
+                # the stream ended or a non-JSON line was received
+ yield create_final_llm_result_chunk(
+ index=chunk_index,
+ message=AssistantPromptMessage(content=""),
+ finish_reason="Non-JSON encountered."
+ )
+
+ chunk_index += 1
+ break
+
+ if completion_type is LLMMode.CHAT:
+ if not chunk_json:
+ continue
+
+ if 'message' not in chunk_json:
+ text = ''
+ else:
+ text = chunk_json.get('message').get('content', '')
+ else:
+ if not chunk_json:
+ continue
+
+ # transform assistant message to prompt message
+ text = chunk_json['response']
+
+ assistant_prompt_message = AssistantPromptMessage(
+ content=text
+ )
+
+ full_text += text
+
+ if chunk_json['done']:
+ # calculate num tokens
+ if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
+ # transform usage
+ prompt_tokens = chunk_json["prompt_eval_count"]
+ completion_tokens = chunk_json["eval_count"]
+ else:
+ # calculate num tokens
+ prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+ completion_tokens = self._get_num_tokens_by_gpt2(full_text)
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ yield LLMResultChunk(
+ model=chunk_json['model'],
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=chunk_index,
+ message=assistant_prompt_message,
+ finish_reason='stop',
+ usage=usage
+ )
+ )
+ else:
+ yield LLMResultChunk(
+ model=chunk_json['model'],
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=chunk_index,
+ message=assistant_prompt_message,
+ )
+ )
+
+ chunk_index += 1
+
+ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+ """
+ Convert PromptMessage to dict for Ollama API
+ """
+ if isinstance(message, UserPromptMessage):
+ message = cast(UserPromptMessage, message)
+ if isinstance(message.content, str):
+ message_dict = {"role": "user", "content": message.content}
+ else:
+ text = ''
+ images = []
+ for message_content in message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(TextPromptMessageContent, message_content)
+ text = message_content.data
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(ImagePromptMessageContent, message_content)
+ image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+ images.append(image_data)
+
+ message_dict = {"role": "user", "content": text, "images": images}
+ elif isinstance(message, AssistantPromptMessage):
+ message = cast(AssistantPromptMessage, message)
+ message_dict = {"role": "assistant", "content": message.content}
+ elif isinstance(message, SystemPromptMessage):
+ message = cast(SystemPromptMessage, message)
+ message_dict = {"role": "system", "content": message.content}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ return message_dict
+
+ def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
+ """
+ Calculate num tokens.
+
+ :param messages: messages
+ """
+ num_tokens = 0
+ messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ for key, value in message.items():
+ num_tokens += self._get_num_tokens_by_gpt2(str(key))
+ num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+ return num_tokens
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ Get customizable model schema.
+
+ :param model: model name
+ :param credentials: credentials
+
+ :return: model schema
+ """
+ extras = {}
+
+ if 'vision_support' in credentials and credentials['vision_support'] == 'true':
+ extras['features'] = [ModelFeature.VISION]
+
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ zh_Hans=model,
+ en_US=model
+ ),
+ model_type=ModelType.LLM,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.MODE: credentials.get('mode'),
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+ },
+ parameter_rules=[
+ ParameterRule(
+ name=DefaultParameterName.TEMPERATURE.value,
+ use_template=DefaultParameterName.TEMPERATURE.value,
+ label=I18nObject(en_US="Temperature"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="The temperature of the model. "
+ "Increasing the temperature will make the model answer "
+ "more creatively. (Default: 0.8)"),
+ default=0.8,
+ min=0,
+ max=2
+ ),
+ ParameterRule(
+ name=DefaultParameterName.TOP_P.value,
+ use_template=DefaultParameterName.TOP_P.value,
+ label=I18nObject(en_US="Top P"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
+ "more diverse text, while a lower value (e.g., 0.5) will generate more "
+ "focused and conservative text. (Default: 0.9)"),
+ default=0.9,
+ min=0,
+ max=1
+ ),
+ ParameterRule(
+ name="top_k",
+ label=I18nObject(en_US="Top K"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Reduces the probability of generating nonsense. "
+ "A higher value (e.g. 100) will give more diverse answers, "
+ "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
+ default=40,
+ min=1,
+ max=100
+ ),
+ ParameterRule(
+ name='repeat_penalty',
+ label=I18nObject(en_US="Repeat Penalty"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
+ "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
+ "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
+ default=1.1,
+ min=-2,
+ max=2
+ ),
+ ParameterRule(
+ name='num_predict',
+ use_template='max_tokens',
+ label=I18nObject(en_US="Num Predict"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
+ "(Default: 128, -1 = infinite generation, -2 = fill context)"),
+ default=128,
+ min=-2,
+ max=int(credentials.get('max_tokens', 4096)),
+ ),
+ ParameterRule(
+ name='mirostat',
+ label=I18nObject(en_US="Mirostat sampling"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
+ "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
+ default=0,
+ min=0,
+ max=2
+ ),
+ ParameterRule(
+ name='mirostat_eta',
+ label=I18nObject(en_US="Mirostat Eta"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
+ "the generated text. A lower learning rate will result in slower adjustments, "
+ "while a higher learning rate will make the algorithm more responsive. "
+ "(Default: 0.1)"),
+ default=0.1,
+ precision=1
+ ),
+ ParameterRule(
+ name='mirostat_tau',
+ label=I18nObject(en_US="Mirostat Tau"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
+ "A lower value will result in more focused and coherent text. (Default: 5.0)"),
+ default=5.0,
+ precision=1
+ ),
+ ParameterRule(
+ name='num_ctx',
+ label=I18nObject(en_US="Size of context window"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
+ "(Default: 2048)"),
+ default=2048,
+ min=1
+ ),
+ ParameterRule(
+ name='num_gpu',
+ label=I18nObject(en_US="Num GPU"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="The number of layers to send to the GPU(s). "
+ "On macOS it defaults to 1 to enable metal support, 0 to disable."),
+ default=1,
+ min=0,
+ max=1
+ ),
+ ParameterRule(
+ name='num_thread',
+ label=I18nObject(en_US="Num Thread"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Sets the number of threads to use during computation. "
+ "By default, Ollama will detect this for optimal performance. "
+ "It is recommended to set this value to the number of physical CPU cores "
+ "your system has (as opposed to the logical number of cores)."),
+ min=1,
+ ),
+ ParameterRule(
+ name='repeat_last_n',
+ label=I18nObject(en_US="Repeat last N"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
+ "(Default: 64, 0 = disabled, -1 = num_ctx)"),
+ default=64,
+ min=-1
+ ),
+ ParameterRule(
+ name='tfs_z',
+ label=I18nObject(en_US="TFS Z"),
+ type=ParameterType.FLOAT,
+ help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
+ "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
+ "while a value of 1.0 disables this setting. (default: 1)"),
+ default=1,
+ precision=1
+ ),
+ ParameterRule(
+ name='seed',
+ label=I18nObject(en_US="Seed"),
+ type=ParameterType.INT,
+ help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
+ "a specific number will make the model generate the same text for "
+ "the same prompt. (Default: 0)"),
+ default=0
+ ),
+ ParameterRule(
+ name='format',
+ label=I18nObject(en_US="Format"),
+ type=ParameterType.STRING,
+ help=I18nObject(en_US="the format to return a response in."
+ " Currently the only accepted value is json."),
+ options=['json'],
+ )
+ ],
+ pricing=PriceConfig(
+ input=Decimal(credentials.get('input_price', 0)),
+ output=Decimal(credentials.get('output_price', 0)),
+ unit=Decimal(credentials.get('unit', 0)),
+ currency=credentials.get('currency', "USD")
+ ),
+ **extras
+ )
+
+ return entity
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeAuthorizationError: [
+ requests.exceptions.InvalidHeader, # Missing or Invalid API Key
+ ],
+ InvokeBadRequestError: [
+ requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
+ requests.exceptions.InvalidURL, # Misconfigured request or other API error
+ ],
+ InvokeRateLimitError: [
+ requests.exceptions.RetryError # Too many requests sent in a short period of time
+ ],
+ InvokeServerUnavailableError: [
+ requests.exceptions.ConnectionError, # Engine Overloaded
+ requests.exceptions.HTTPError # Server Error
+ ],
+ InvokeConnectionError: [
+ requests.exceptions.ConnectTimeout, # Timeout
+ requests.exceptions.ReadTimeout # Timeout
+ ]
+ }
diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py
new file mode 100644
index 0000000000..f8a17b98a0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/ollama.py
@@ -0,0 +1,17 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaProvider(ModelProvider):
+
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+        pass
diff --git a/api/core/model_runtime/model_providers/ollama/ollama.yaml b/api/core/model_runtime/model_providers/ollama/ollama.yaml
new file mode 100644
index 0000000000..d796831461
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/ollama.yaml
@@ -0,0 +1,98 @@
+provider: ollama
+label:
+ en_US: Ollama
+icon_large:
+ en_US: icon_l_en.svg
+icon_small:
+ en_US: icon_s_en.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: How to integrate with Ollama
+ zh_Hans: 如何集成 Ollama
+ url:
+ en_US: https://docs.dify.ai/advanced/model-configuration/ollama
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: base_url
+ label:
+ zh_Hans: 基础 URL
+ en_US: Base URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
+      en_US: Base URL of Ollama server, e.g. http://192.168.1.100:11434
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion mode
+ type: select
+ required: true
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select completion mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 在此输入您的模型上下文长度
+      en_US: Enter your model context size
+ - variable: max_tokens
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: '4096'
+ type: text-input
+ required: true
+ - variable: vision_support
+ label:
+ zh_Hans: 是否支持 Vision
+ en_US: Vision support
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: 'false'
+ type: radio
+ required: false
+ options:
+ - value: 'true'
+ label:
+            en_US: 'Yes'
+ zh_Hans: 是
+ - value: 'false'
+ label:
+            en_US: 'No'
+ zh_Hans: 否
diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py b/api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..f496898180
--- /dev/null
+++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
@@ -0,0 +1,221 @@
+import logging
+import time
+from decimal import Decimal
+from typing import Optional
+from urllib.parse import urljoin
+import requests
+import json
+
+import numpy as np
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
+ PriceConfig
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
+ InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingModel(TextEmbeddingModel):
+ """
+ Model class for an Ollama text embedding model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ texts: list[str], user: Optional[str] = None) \
+ -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :param user: unique user id
+ :return: embeddings result
+ """
+
+ # Prepare headers and payload for the request
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+
+ endpoint_url = credentials.get('base_url')
+ if not endpoint_url.endswith('/'):
+ endpoint_url += '/'
+
+ endpoint_url = urljoin(endpoint_url, 'api/embeddings')
+
+ # get model properties
+ context_size = self._get_context_size(model, credentials)
+
+ inputs = []
+ used_tokens = 0
+
+ for i, text in enumerate(texts):
+ # Here token count is only an approximation based on the GPT2 tokenizer
+ num_tokens = self._get_num_tokens_by_gpt2(text)
+
+ if num_tokens >= context_size:
+                cutoff = int(np.floor(len(text) * context_size / num_tokens))
+ # if num tokens is larger than context length, only use the start
+ inputs.append(text[0: cutoff])
+ else:
+ inputs.append(text)
+
+ batched_embeddings = []
+
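+        # Ollama's embeddings endpoint embeds a single prompt per request, so iterate over the inputs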
+ for text in inputs:
+ # Prepare the payload for the request
+ payload = {
+ 'prompt': text,
+ 'model': model,
+ }
+
+            # Make the request to the Ollama embeddings API
+ response = requests.post(
+ endpoint_url,
+ headers=headers,
+ data=json.dumps(payload),
+ timeout=(10, 300)
+ )
+
+ response.raise_for_status() # Raise an exception for HTTP errors
+ response_data = response.json()
+
+ # Extract embeddings and used tokens from the response
+ embeddings = response_data['embedding']
+ embedding_used_tokens = self.get_num_tokens(model, credentials, [text])
+
+ used_tokens += embedding_used_tokens
+ batched_embeddings.append(embeddings)
+
+ # calc usage
+ usage = self._calc_response_usage(
+ model=model,
+ credentials=credentials,
+ tokens=used_tokens
+ )
+
+ return TextEmbeddingResult(
+ embeddings=batched_embeddings,
+ usage=usage,
+ model=model
+ )
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Approximate number of tokens for given messages using GPT2 tokenizer
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(
+ model=model,
+ credentials=credentials,
+ texts=['ping']
+ )
+ except InvokeError as ex:
+ raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+ except Exception as ex:
+ raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+ ModelPropertyKey.MAX_CHUNKS: 1,
+ },
+ parameter_rules=[],
+ pricing=PriceConfig(
+ input=Decimal(credentials.get('input_price', 0)),
+ unit=Decimal(credentials.get('unit', 0)),
+ currency=credentials.get('currency', "USD")
+ )
+ )
+
+ return entity
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeAuthorizationError: [
+ requests.exceptions.InvalidHeader, # Missing or Invalid API Key
+ ],
+ InvokeBadRequestError: [
+ requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
+ requests.exceptions.InvalidURL, # Misconfigured request or other API error
+ ],
+ InvokeRateLimitError: [
+ requests.exceptions.RetryError # Too many requests sent in a short period of time
+ ],
+ InvokeServerUnavailableError: [
+ requests.exceptions.ConnectionError, # Engine Overloaded
+ requests.exceptions.HTTPError # Server Error
+ ],
+ InvokeConnectionError: [
+ requests.exceptions.ConnectTimeout, # Timeout
+ requests.exceptions.ReadTimeout # Timeout
+ ]
+ }
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index b92b6dce3f..acb974b050 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
+ break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index 26925606b2..088738c0ff 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
- zh_Hans: Base URL, eg. https://api.openai.com/v1
- en_US: Base URL, eg. https://api.openai.com/v1
+ zh_Hans: Base URL, e.g. https://api.openai.com/v1
+ en_US: Base URL, e.g. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type
diff --git a/api/core/model_runtime/model_providers/openllm/openllm.yaml b/api/core/model_runtime/model_providers/openllm/openllm.yaml
index bd93baa727..fef52695e3 100644
--- a/api/core/model_runtime/model_providers/openllm/openllm.yaml
+++ b/api/core/model_runtime/model_providers/openllm/openllm.yaml
@@ -33,5 +33,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
- zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
- en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
+ zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
+ en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml
index f5391d0324..bb6c6d8668 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference.yaml
+++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml
@@ -34,8 +34,8 @@ model_credential_schema:
type: secret-input
required: true
placeholder:
- zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
- en_US: Enter the url of your Xinference, for example https://example.com/xxx
+ zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+ en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- variable: model_uid
label:
zh_Hans: 模型UID
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 28fea7c3ce..19c8e2d5ad 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -121,6 +121,7 @@ class PromptTransform:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
+ files=files,
context=context,
memory=memory,
model_config=model_config
@@ -343,7 +344,14 @@ class PromptTransform:
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
- prompt_message = UserPromptMessage(content=prompt)
+ if files:
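+                # attach uploaded files (e.g. images) to the prompt as multimodal content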
+ prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+ for file in files:
+ prompt_message_contents.append(file.prompt_message_content)
+
+ prompt_message = UserPromptMessage(content=prompt_message_contents)
+ else:
+ prompt_message = UserPromptMessage(content=prompt)
return [prompt_message]
@@ -434,6 +442,7 @@ class PromptTransform:
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
+ files: List[FileObj],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> List[PromptMessage]:
@@ -461,7 +470,14 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
- prompt_messages.append(UserPromptMessage(content=prompt))
+ if files:
+ prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+ for file in files:
+ prompt_message_contents.append(file.prompt_message_content)
+
+ prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+ else:
+ prompt_messages.append(UserPromptMessage(content=prompt))
return prompt_messages
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 89080b0788..04abacf73d 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -62,5 +62,8 @@ COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
+# Ollama Credentials
+OLLAMA_BASE_URL=
+
# Mock Switch
MOCK_SWITCH=false
\ No newline at end of file
diff --git a/api/tests/integration_tests/model_runtime/ollama/__init__.py b/api/tests/integration_tests/model_runtime/ollama/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py
new file mode 100644
index 0000000000..5543085a54
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py
@@ -0,0 +1,260 @@
+import os
+from typing import Generator
+
+import pytest
+
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
+ SystemPromptMessage, TextPromptMessageContent, ImagePromptMessageContent
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
+ LLMResultChunk
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel
+
+
+def test_validate_credentials():
+ model = OllamaLargeLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
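+        # an unreachable base_url should cause validation to fail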
+ model.validate_credentials(
+ model='mistral:text',
+ credentials={
+ 'base_url': 'http://localhost:21434',
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ }
+ )
+
+ model.validate_credentials(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ }
+ )
+
+
+def test_invoke_model():
+ model = OllamaLargeLanguageModel()
+
+ response = model.invoke(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ 'num_predict': 10
+ },
+ stop=['How'],
+ stream=False
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+ model = OllamaLargeLanguageModel()
+
+ response = model.invoke(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ 'num_predict': 10
+ },
+ stop=['How'],
+ stream=True
+ )
+
+ assert isinstance(response, Generator)
+
+ for chunk in response:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_invoke_completion_model():
+ model = OllamaLargeLanguageModel()
+
+ response = model.invoke(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'completion',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ 'num_predict': 10
+ },
+ stop=['How'],
+ stream=False
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+
+
+def test_invoke_stream_completion_model():
+ model = OllamaLargeLanguageModel()
+
+ response = model.invoke(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'completion',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ 'num_predict': 10
+ },
+ stop=['How'],
+ stream=True
+ )
+
+ assert isinstance(response, Generator)
+
+ for chunk in response:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_invoke_completion_model_with_vision():
+ model = OllamaLargeLanguageModel()
+
+ result = model.invoke(
+ model='llava',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'completion',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content=[
+ TextPromptMessageContent(
+ data='What is this in this picture?',
+ ),
+ ImagePromptMessageContent(
+ data='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+P
W/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48
NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC'
+ )
+ ]
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.1,
+ 'num_predict': 100
+ },
+ stream=False,
+ )
+
+ assert isinstance(result, LLMResult)
+ assert len(result.message.content) > 0
+
+
+def test_invoke_chat_model_with_vision():
+ model = OllamaLargeLanguageModel()
+
+ result = model.invoke(
+ model='llava',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content=[
+ TextPromptMessageContent(
+ data='What is this in this picture?',
+ ),
+ ImagePromptMessageContent(
+ data='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+P
W/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48
NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC'
+ )
+ ]
+ )
+ ],
+ model_parameters={
+ 'temperature': 0.1,
+ 'num_predict': 100
+ },
+ stream=False,
+ )
+
+ assert isinstance(result, LLMResult)
+ assert len(result.message.content) > 0
+
+
+def test_get_num_tokens():
+ model = OllamaLargeLanguageModel()
+
+ num_tokens = model.get_num_tokens(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 2048,
+ 'max_tokens': 2048,
+ },
+ prompt_messages=[
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ]
+ )
+
+ assert isinstance(num_tokens, int)
+ assert num_tokens == 6
diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
new file mode 100644
index 0000000000..c5f5918235
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
@@ -0,0 +1,77 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
+
+
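+# Credentials validation should raise for an unreachable base_url (no server on
+# this port) and pass when OLLAMA_BASE_URL points at a running Ollama server.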
+def test_validate_credentials():
+ model = OllamaEmbeddingModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='mistral:text',
+ credentials={
+ 'base_url': 'http://localhost:21434',
+ 'mode': 'chat',
+ 'context_size': 4096,
+ }
+ )
+
+ model.validate_credentials(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 4096,
+ }
+ )
+
+
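+# Invoking the embedding model should return one embedding vector per input
+# text, with token usage reported in the result.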
+def test_invoke_model():
+ model = OllamaEmbeddingModel()
+
+ result = model.invoke(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 4096,
+ },
+ texts=[
+ "hello",
+ "world"
+ ],
+ user="abc-123"
+ )
+
+ assert isinstance(result, TextEmbeddingResult)
+ assert len(result.embeddings) == 2
+ assert result.usage.total_tokens == 2
+
+
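+# Token counting for the embedding model: the two single-word inputs are
+# expected to count as two tokens in total.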
+def test_get_num_tokens():
+ model = OllamaEmbeddingModel()
+
+ num_tokens = model.get_num_tokens(
+ model='mistral:text',
+ credentials={
+ 'base_url': os.environ.get('OLLAMA_BASE_URL'),
+ 'mode': 'chat',
+ 'context_size': 4096,
+ },
+ texts=[
+ "hello",
+ "world"
+ ]
+ )
+
+ assert num_tokens == 2