feat: ollama support (#2003)

takatost 2024-01-12 12:29:13 +08:00 committed by GitHub
parent 5e75f7022f
commit cca9edc97a
21 changed files with 1369 additions and 13 deletions

View File

@ -459,10 +459,33 @@ class GenerateTaskPipeline:
"files": files
})
else:
prompts.append({
prompt_message = prompt_messages[0]
text = ''
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
else:
text = prompt_message.content
params = {
"role": 'user',
"text": prompt_messages[0].content
})
"text": text,
}
if files:
params['files'] = files
prompts.append(params)
return prompts
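For context, here is a minimal sketch (not part of the commit) of the prompt entry the new branch produces for a user message carrying one image. Only the dict shape and the truncation rule come from the hunk above; the literal values are made up.

# Hypothetical values; mirrors content.data[:10] + '...[TRUNCATED]...' + content.data[-10:]
image_b64 = "iVBORw0KGgoAAAANSUhEUg" + "A" * 200 + "AAAASUVORK5CYII="

params = {
    "role": 'user',
    "text": "Describe this image.",
}
files = [{
    "type": 'image',
    "data": image_b64[:10] + '...[TRUNCATED]...' + image_b64[-10:],
    "detail": "low",  # stands in for content.detail.value
}]
if files:
    params['files'] = files
# params now carries role/text plus the truncated image payload, keeping logs small.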

View File

@ -6,6 +6,7 @@
- huggingface_hub
- cohere
- togetherai
- ollama
- zhipuai
- baichuan
- spark

View File

@ -54,5 +54,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入LocalAI的服务器地址如 https://example.com/xxx
en_US: Enter the url of your LocalAI, for example https://example.com/xxx
zh_Hans: 在此输入LocalAI的服务器地址如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080

(Two new image files added, 12 KiB and 7.7 KiB — diffs suppressed; presumably the Ollama provider icons referenced in ollama.yaml below.)

View File

@ -0,0 +1,615 @@
import json
import logging
import re
from decimal import Decimal
from typing import Optional, Generator, Union, List, cast
from urllib.parse import urljoin
import requests
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
TextPromptMessageContent, SystemPromptMessage
from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class OllamaLargeLanguageModel(LargeLanguageModel):
"""
Model class for Ollama large language model.
"""
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if model_mode == LLMMode.CHAT:
# chat model
return self._num_tokens_from_messages(prompt_messages)
else:
first_prompt_message = prompt_messages[0]
if isinstance(first_prompt_message.content, str):
text = first_prompt_message.content
else:
text = ''
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
break
return self._get_num_tokens_by_gpt2(text)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._generate(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
'num_predict': 5
},
stream=False
)
except InvokeError as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
headers = {
'Content-Type': 'application/json'
}
endpoint_url = credentials['base_url']
if not endpoint_url.endswith('/'):
endpoint_url += '/'
# prepare the request payload
data = {
'model': model,
'stream': stream
}
if 'format' in model_parameters:
data['format'] = model_parameters['format']
del model_parameters['format']
data['options'] = model_parameters or {}
if stop:
data['stop'] = "\n".join(stop)
completion_type = LLMMode.value_of(credentials['mode'])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'api/chat')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
else:
endpoint_url = urljoin(endpoint_url, 'api/generate')
first_prompt_message = prompt_messages[0]
if isinstance(first_prompt_message, UserPromptMessage):
first_prompt_message = cast(UserPromptMessage, first_prompt_message)
if isinstance(first_prompt_message.content, str):
data['prompt'] = first_prompt_message.content
else:
text = ''
images = []
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
images.append(image_data)
data['prompt'] = text
data['images'] = images
# send the POST request to the Ollama endpoint
response = requests.post(
endpoint_url,
headers=headers,
json=data,
timeout=(10, 60),
stream=stream
)
response.encoding = "utf-8"
if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm completion response
:param model: model name
:param credentials: model credentials
:param completion_type: completion type
:param response: response
:param prompt_messages: prompt messages
:return: llm result
"""
response_json = response.json()
if completion_type is LLMMode.CHAT:
message = response_json.get('message', {})
response_content = message.get('content', '')
else:
response_content = response_json['response']
assistant_message = AssistantPromptMessage(content=response_content)
if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
# transform usage
prompt_tokens = response_json["prompt_eval_count"]
completion_tokens = response_json["eval_count"]
else:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response_json["model"],
prompt_messages=prompt_messages,
message=assistant_message,
usage=usage,
)
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm completion stream response
:param model: model name
:param credentials: model credentials
:param completion_type: completion type
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
full_text = ''
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
)
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
if not chunk:
continue
try:
chunk_json = json.loads(chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
chunk_index += 1
break
if completion_type is LLMMode.CHAT:
if not chunk_json:
continue
if 'message' not in chunk_json:
text = ''
else:
text = chunk_json.get('message').get('content', '')
else:
if not chunk_json:
continue
# transform assistant message to prompt message
text = chunk_json['response']
assistant_prompt_message = AssistantPromptMessage(
content=text
)
full_text += text
if chunk_json['done']:
# calculate num tokens
if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
# transform usage
prompt_tokens = chunk_json["prompt_eval_count"]
completion_tokens = chunk_json["eval_count"]
else:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk_json['model'],
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
finish_reason='stop',
usage=usage
)
)
else:
yield LLMResultChunk(
model=chunk_json['model'],
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
)
chunk_index += 1
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for Ollama API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
text = ''
images = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
images.append(image_data)
message_dict = {"role": "user", "content": text, "images": images}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
"""
Calculate num tokens.
:param messages: messages
"""
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
for key, value in message.items():
num_tokens += self._get_num_tokens_by_gpt2(str(key))
num_tokens += self._get_num_tokens_by_gpt2(str(value))
return num_tokens
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.
:param model: model name
:param credentials: credentials
:return: model schema
"""
extras = {}
if 'vision_support' in credentials and credentials['vision_support'] == 'true':
extras['features'] = [ModelFeature.VISION]
entity = AIModelEntity(
model=model,
label=I18nObject(
zh_Hans=model,
en_US=model
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: credentials.get('mode'),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
use_template=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="The temperature of the model. "
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"),
default=0.8,
min=0,
max=2
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
use_template=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"),
default=0.9,
min=0,
max=1
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
help=I18nObject(en_US="Reduces the probability of generating nonsense. "
"A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
default=40,
min=1,
max=100
),
ParameterRule(
name='repeat_penalty',
label=I18nObject(en_US="Repeat Penalty"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
default=1.1,
min=-2,
max=2
),
ParameterRule(
name='num_predict',
use_template='max_tokens',
label=I18nObject(en_US="Num Predict"),
type=ParameterType.INT,
help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)"),
default=128,
min=-2,
max=int(credentials.get('max_tokens', 4096)),
),
ParameterRule(
name='mirostat',
label=I18nObject(en_US="Mirostat sampling"),
type=ParameterType.INT,
help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
default=0,
min=0,
max=2
),
ParameterRule(
name='mirostat_eta',
label=I18nObject(en_US="Mirostat Eta"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
"the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)"),
default=0.1,
precision=1
),
ParameterRule(
name='mirostat_tau',
label=I18nObject(en_US="Mirostat Tau"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
"A lower value will result in more focused and coherent text. (Default: 5.0)"),
default=5.0,
precision=1
),
ParameterRule(
name='num_ctx',
label=I18nObject(en_US="Size of context window"),
type=ParameterType.INT,
help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)"),
default=2048,
min=1
),
ParameterRule(
name='num_gpu',
label=I18nObject(en_US="Num GPU"),
type=ParameterType.INT,
help=I18nObject(en_US="The number of layers to send to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."),
default=1,
min=0,
max=1
),
ParameterRule(
name='num_thread',
label=I18nObject(en_US="Num Thread"),
type=ParameterType.INT,
help=I18nObject(en_US="Sets the number of threads to use during computation. "
"By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)."),
min=1,
),
ParameterRule(
name='repeat_last_n',
label=I18nObject(en_US="Repeat last N"),
type=ParameterType.INT,
help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
"(Default: 64, 0 = disabled, -1 = num_ctx)"),
default=64,
min=-1
),
ParameterRule(
name='tfs_z',
label=I18nObject(en_US="TFS Z"),
type=ParameterType.FLOAT,
help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)"),
default=1,
precision=1
),
ParameterRule(
name='seed',
label=I18nObject(en_US="Seed"),
type=ParameterType.INT,
help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
"a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)"),
default=0
),
ParameterRule(
name='format',
label=I18nObject(en_US="Format"),
type=ParameterType.STRING,
help=I18nObject(en_US="the format to return a response in."
" Currently the only accepted value is json."),
options=['json'],
)
],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
output=Decimal(credentials.get('output_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
),
**extras
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeAuthorizationError: [
requests.exceptions.InvalidHeader, # Missing or Invalid API Key
],
InvokeBadRequestError: [
requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
requests.exceptions.InvalidURL, # Misconfigured request or other API error
],
InvokeRateLimitError: [
requests.exceptions.RetryError # Too many requests sent in a short period of time
],
InvokeServerUnavailableError: [
requests.exceptions.ConnectionError, # Engine Overloaded
requests.exceptions.HTTPError # Server Error
],
InvokeConnectionError: [
requests.exceptions.ConnectTimeout, # Timeout
requests.exceptions.ReadTimeout # Timeout
]
}
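As a quick reference (not part of the diff), the sketch below shows the raw HTTP calls that _generate builds: chat mode posts to api/chat with a messages list, completion mode posts to api/generate with a prompt (and optional images). The server address and model tag are assumptions.

from urllib.parse import urljoin

import requests

base_url = "http://localhost:11434/"  # assumed local Ollama server

# Chat mode -> POST {base_url}api/chat with a "messages" list.
chat_payload = {
    "model": "llama2",  # assumed model tag
    "stream": False,
    "options": {"temperature": 0.8, "num_predict": 64},
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
}
chat_resp = requests.post(urljoin(base_url, "api/chat"), json=chat_payload, timeout=(10, 60))
print(chat_resp.json().get("message", {}).get("content", ""))

# Completion mode -> POST {base_url}api/generate with "prompt" (and optional "images").
gen_payload = {
    "model": "llama2",
    "stream": False,
    "options": {"temperature": 0.8},
    "prompt": "Why is the sky blue?",
}
gen_resp = requests.post(urljoin(base_url, "api/generate"), json=gen_payload, timeout=(10, 60))
print(gen_resp.json()["response"])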

View File

@ -0,0 +1,17 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass

View File

@ -0,0 +1,98 @@
provider: ollama
label:
en_US: Ollama
icon_large:
en_US: icon_l_en.svg
icon_small:
en_US: icon_s_en.svg
background: "#F9FAFB"
help:
title:
en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama
url:
en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
- llm
- text-embedding
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: base_url
label:
zh_Hans: 基础 URL
en_US: Base URL
type: text-input
required: true
placeholder:
zh_Hans: Ollama server 的基础 URL例如 http://192.168.1.100:11434
en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: 模型类型
en_US: Completion mode
type: select
required: true
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: '4096'
type: text-input
required: true
- variable: vision_support
label:
zh_Hans: 是否支持 Vision
en_US: Vision support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: Yes
zh_Hans: 是
- value: 'false'
label:
en_US: No
zh_Hans: 否
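For reference, the credentials dict a user of this schema ends up supplying looks roughly like the following (values are illustrative); get_customizable_model_schema and _generate in llm.py read these keys.

ollama_llm_credentials = {
    'base_url': 'http://192.168.1.100:11434',  # Ollama server
    'mode': 'chat',                            # or 'completion'
    'context_size': '4096',
    'max_tokens': '4096',
    'vision_support': 'false',
}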

View File

@ -0,0 +1,221 @@
import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
import numpy as np
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
logger = logging.getLogger(__name__)
class OllamaEmbeddingModel(TextEmbeddingModel):
"""
Model class for an Ollama text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# Prepare headers and payload for the request
headers = {
'Content-Type': 'application/json'
}
endpoint_url = credentials.get('base_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = urljoin(endpoint_url, 'api/embeddings')
# get model properties
context_size = self._get_context_size(model, credentials)
inputs = []
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
else:
inputs.append(text)
batched_embeddings = []
for text in inputs:
# Prepare the payload for the request
payload = {
'prompt': text,
'model': model,
}
# Make the request to the Ollama embeddings API
response = requests.post(
endpoint_url,
headers=headers,
data=json.dumps(payload),
timeout=(10, 300)
)
response.raise_for_status() # Raise an exception for HTTP errors
response_data = response.json()
# Extract embeddings and used tokens from the response
embeddings = response_data['embedding']
embedding_used_tokens = self.get_num_tokens(model, credentials, [text])
used_tokens += embedding_used_tokens
batched_embeddings.append(embeddings)
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
texts=['ping']
)
except InvokeError as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)
return entity
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeAuthorizationError: [
requests.exceptions.InvalidHeader, # Missing or Invalid API Key
],
InvokeBadRequestError: [
requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
requests.exceptions.InvalidURL, # Misconfigured request or other API error
],
InvokeRateLimitError: [
requests.exceptions.RetryError # Too many requests sent in a short period of time
],
InvokeServerUnavailableError: [
requests.exceptions.ConnectionError, # Engine Overloaded
requests.exceptions.HTTPError # Server Error
],
InvokeConnectionError: [
requests.exceptions.ConnectTimeout, # Timeout
requests.exceptions.ReadTimeout # Timeout
]
}
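A standalone sketch (not from the commit) of the embeddings call that _invoke wraps; the server address and model tag are assumptions, while the endpoint and payload fields mirror the code above.

import json
from urllib.parse import urljoin

import requests

base_url = "http://localhost:11434/"  # assumed local Ollama server
payload = {"model": "nomic-embed-text", "prompt": "hello world"}  # assumed embedding-capable model tag
resp = requests.post(
    urljoin(base_url, "api/embeddings"),
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
    timeout=(10, 300),
)
resp.raise_for_status()
vector = resp.json()["embedding"]  # list[float]
print(len(vector))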

View File

@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
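Aside (not part of the hunk): the added break matters because iteration over the stream should stop once a non-JSON line arrives (for OpenAI-compatible APIs, typically the terminating [DONE] marker), instead of continuing to read after the final chunk is yielded. A minimal illustration:

import json

lines = ['{"choices": [{"delta": {"content": "Hi"}}]}', '[DONE]']  # '[DONE]' stands in for any non-JSON terminator

for line in lines:
    try:
        chunk_json = json.loads(line)
    except json.JSONDecodeError:
        print('finish: Non-JSON encountered.')
        break  # mirrors the line added above
    print(chunk_json['choices'][0]['delta']['content'])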

View File

@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
zh_Hans: Base URL, e.g. https://api.openai.com/v1
en_US: Base URL, e.g. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type

View File

@ -33,5 +33,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入OpenLLM的服务器地址如 https://example.com/xxx
en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
zh_Hans: 在此输入OpenLLM的服务器地址如 http://192.168.1.100:3000
en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000

View File

@ -34,8 +34,8 @@ model_credential_schema:
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
zh_Hans: 在此输入Xinference的服务器地址如 http://192.168.1.100:9997
en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- variable: model_uid
label:
zh_Hans: 模型UID

View File

@ -121,6 +121,7 @@ class PromptTransform:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config
@ -343,7 +344,14 @@ class PromptTransform:
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
return [prompt_message]
@ -434,6 +442,7 @@ class PromptTransform:
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: List[FileObj],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> List[PromptMessage]:
@ -461,7 +470,14 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(UserPromptMessage(content=prompt))
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=prompt))
return prompt_messages
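Not part of the diff: a sketch of what the files branch produces, assuming each FileObj's prompt_message_content is an ImagePromptMessageContent as elsewhere in this commit (detail is assumed to default when omitted).

from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

prompt = "Describe the attached image."
# Stand-ins for file.prompt_message_content; the data URL is a dummy payload.
files = [ImagePromptMessageContent(data="data:image/png;base64,AAAA")]

if files:
    prompt_message_contents = [TextPromptMessageContent(data=prompt)]
    for file_content in files:
        prompt_message_contents.append(file_content)
    prompt_message = UserPromptMessage(content=prompt_message_contents)  # multimodal message
else:
    prompt_message = UserPromptMessage(content=prompt)  # plain text message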

View File

@ -62,5 +62,8 @@ COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
# Ollama Credentials
OLLAMA_BASE_URL=
# Mock Switch
MOCK_SWITCH=false

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,71 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
def test_validate_credentials():
model = OllamaEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistral:text',
credentials={
'base_url': 'http://localhost:21434',
'mode': 'chat',
'context_size': 4096,
}
)
model.validate_credentials(
model='mistral:text',
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
}
)
def test_invoke_model():
model = OllamaEmbeddingModel()
result = model.invoke(
model='mistral:text',
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
},
texts=[
"hello",
"world"
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
def test_get_num_tokens():
model = OllamaEmbeddingModel()
num_tokens = model.get_num_tokens(
model='mistral:text',
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
},
texts=[
"hello",
"world"
]
)
assert num_tokens == 2
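A possible companion smoke test for the LLM model (not included in this commit); the import path ollama.llm.llm and the credential keys below mirror the embedding test above and are assumptions to verify against the repository layout.

import os

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel


def test_invoke_llm_non_stream():
    model = OllamaLargeLanguageModel()

    result = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'completion',
            'context_size': 4096,
            'max_tokens': 4096,
        },
        prompt_messages=[UserPromptMessage(content='Hello')],
        model_parameters={'num_predict': 16},
        stream=False,
        user='abc-123',
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0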