mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 03:32:23 +08:00)

feat: ollama support (#2003)

Commit: cca9edc97a
Parent: 5e75f7022f
@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
                     "files": files
                 })
         else:
-            prompts.append({
+            prompt_message = prompt_messages[0]
+            text = ''
+            files = []
+            if isinstance(prompt_message.content, list):
+                for content in prompt_message.content:
+                    if content.type == PromptMessageContentType.TEXT:
+                        content = cast(TextPromptMessageContent, content)
+                        text += content.data
+                    else:
+                        content = cast(ImagePromptMessageContent, content)
+                        files.append({
+                            "type": 'image',
+                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                            "detail": content.detail.value
+                        })
+            else:
+                text = prompt_message.content
+
+            params = {
                 "role": 'user',
-                "text": prompt_messages[0].content
-            })
+                "text": text,
+            }
+
+            if files:
+                params['files'] = files
+
+            prompts.append(params)
 
         return prompts
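For reference, a minimal sketch (not from this commit) of the dict the rewritten branch appends to prompts for a multimodal user message; the prompt text, the base64 payload and the 'low' detail value are placeholders, and the truncation mirrors the expression above.

# Sketch only: illustrates the saved-prompt shape built by the new branch.
image_b64 = 'iVBORw0KGgoAAAANSUhEUg' + 'A' * 100  # hypothetical base64 image payload
entry = {
    "role": 'user',
    "text": 'Describe this image',  # placeholder text content
    "files": [{
        "type": 'image',
        "data": image_b64[:10] + '...[TRUNCATED]...' + image_b64[-10:],
        "detail": 'low',  # assumed value of content.detail.value
    }],
}
print(entry)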
@@ -6,6 +6,7 @@
 - huggingface_hub
 - cohere
 - togetherai
+- ollama
 - zhipuai
 - baichuan
 - spark
@@ -54,5 +54,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your LocalAI, for example https://example.com/xxx
+        zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
Two new SVG provider icons added (icon_l_en.svg and icon_s_en.svg; 12 KiB and 7.7 KiB); diffs suppressed because one or more lines are too long.
api/core/model_runtime/model_providers/ollama/llm/llm.py (new file, 615 lines)
@@ -0,0 +1,615 @@
import json
import logging
import re
from decimal import Decimal
from typing import Optional, Generator, Union, List, cast
from urllib.parse import urljoin

import requests

from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
    UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
    TextPromptMessageContent, SystemPromptMessage
from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
    PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
    ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
    LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class OllamaLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Ollama large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        return self._generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            stop=stop,
            stream=stream,
            user=user
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if model_mode == LLMMode.CHAT:
            # chat model
            return self._num_tokens_from_messages(prompt_messages)
        else:
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message.content, str):
                text = first_prompt_message.content
            else:
                text = ''
                for message_content in first_prompt_message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        text = message_content.data
                        break
            return self._get_num_tokens_by_gpt2(text)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._generate(
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={
                    'num_predict': 5
                },
                stream=False
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials['base_url']
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'

        # prepare the payload for the request
        data = {
            'model': model,
            'stream': stream
        }

        if 'format' in model_parameters:
            data['format'] = model_parameters['format']
            del model_parameters['format']

        data['options'] = model_parameters or {}

        if stop:
            data['stop'] = "\n".join(stop)

        completion_type = LLMMode.value_of(credentials['mode'])

        if completion_type is LLMMode.CHAT:
            endpoint_url = urljoin(endpoint_url, 'api/chat')
            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
        else:
            endpoint_url = urljoin(endpoint_url, 'api/generate')
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message, UserPromptMessage):
                first_prompt_message = cast(UserPromptMessage, first_prompt_message)
                if isinstance(first_prompt_message.content, str):
                    data['prompt'] = first_prompt_message.content
                else:
                    text = ''
                    images = []
                    for message_content in first_prompt_message.content:
                        if message_content.type == PromptMessageContentType.TEXT:
                            message_content = cast(TextPromptMessageContent, message_content)
                            text = message_content.data
                        elif message_content.type == PromptMessageContentType.IMAGE:
                            message_content = cast(ImagePromptMessageContent, message_content)
                            image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                            images.append(image_data)

                    data['prompt'] = text
                    data['images'] = images

        # send the request to the Ollama server
        response = requests.post(
            endpoint_url,
            headers=headers,
            json=data,
            timeout=(10, 60),
            stream=stream
        )

        response.encoding = "utf-8"
        if response.status_code != 200:
            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

        if stream:
            return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                  response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm completion response

        :param model: model name
        :param credentials: model credentials
        :param completion_type: completion type
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm result
        """
        response_json = response.json()

        if completion_type is LLMMode.CHAT:
            message = response_json.get('message', {})
            response_content = message.get('content', '')
        else:
            response_content = response_json['response']

        assistant_message = AssistantPromptMessage(content=response_content)

        if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
            # transform usage
            prompt_tokens = response_json["prompt_eval_count"]
            completion_tokens = response_json["eval_count"]
        else:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=response_json["model"],
            prompt_messages=prompt_messages,
            message=assistant_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                         response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm completion stream response

        :param model: model name
        :param credentials: model credentials
        :param completion_type: completion type
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        full_text = ''
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(full_text)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
            if not chunk:
                continue

            try:
                chunk_json = json.loads(chunk)
                # stream ended
            except json.JSONDecodeError as e:
                yield create_final_llm_result_chunk(
                    index=chunk_index,
                    message=AssistantPromptMessage(content=""),
                    finish_reason="Non-JSON encountered."
                )

                chunk_index += 1
                break

            if completion_type is LLMMode.CHAT:
                if not chunk_json:
                    continue

                if 'message' not in chunk_json:
                    text = ''
                else:
                    text = chunk_json.get('message').get('content', '')
            else:
                if not chunk_json:
                    continue

                # transform assistant message to prompt message
                text = chunk_json['response']

            assistant_prompt_message = AssistantPromptMessage(
                content=text
            )

            full_text += text

            if chunk_json['done']:
                # calculate num tokens
                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
                    # transform usage
                    prompt_tokens = chunk_json["prompt_eval_count"]
                    completion_tokens = chunk_json["eval_count"]
                else:
                    # calculate num tokens
                    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=chunk_json['model'],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                        finish_reason='stop',
                        usage=usage
                    )
                )
            else:
                yield LLMResultChunk(
                    model=chunk_json['model'],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

            chunk_index += 1

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Ollama API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                text = ''
                images = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        text = message_content.data
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        images.append(image_data)

                message_dict = {"role": "user", "content": text, "images": images}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
        """
        Calculate num tokens.

        :param messages: messages
        """
        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            for key, value in message.items():
                num_tokens += self._get_num_tokens_by_gpt2(str(key))
                num_tokens += self._get_num_tokens_by_gpt2(str(value))

        return num_tokens

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Get customizable model schema.

        :param model: model name
        :param credentials: credentials

        :return: model schema
        """
        extras = {}

        if 'vision_support' in credentials and credentials['vision_support'] == 'true':
            extras['features'] = [ModelFeature.VISION]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                zh_Hans=model,
                en_US=model
            ),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: credentials.get('mode'),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    use_template=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="The temperature of the model. "
                                          "Increasing the temperature will make the model answer "
                                          "more creatively. (Default: 0.8)"),
                    default=0.8,
                    min=0,
                    max=2
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    use_template=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                                          "more diverse text, while a lower value (e.g., 0.5) will generate more "
                                          "focused and conservative text. (Default: 0.9)"),
                    default=0.9,
                    min=0,
                    max=1
                ),
                ParameterRule(
                    name="top_k",
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Reduces the probability of generating nonsense. "
                                          "A higher value (e.g. 100) will give more diverse answers, "
                                          "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
                    default=40,
                    min=1,
                    max=100
                ),
                ParameterRule(
                    name='repeat_penalty',
                    label=I18nObject(en_US="Repeat Penalty"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
                                          "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                                          "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
                    default=1.1,
                    min=-2,
                    max=2
                ),
                ParameterRule(
                    name='num_predict',
                    use_template='max_tokens',
                    label=I18nObject(en_US="Num Predict"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
                                          "(Default: 128, -1 = infinite generation, -2 = fill context)"),
                    default=128,
                    min=-2,
                    max=int(credentials.get('max_tokens', 4096)),
                ),
                ParameterRule(
                    name='mirostat',
                    label=I18nObject(en_US="Mirostat sampling"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
                                          "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
                    default=0,
                    min=0,
                    max=2
                ),
                ParameterRule(
                    name='mirostat_eta',
                    label=I18nObject(en_US="Mirostat Eta"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
                                          "the generated text. A lower learning rate will result in slower adjustments, "
                                          "while a higher learning rate will make the algorithm more responsive. "
                                          "(Default: 0.1)"),
                    default=0.1,
                    precision=1
                ),
                ParameterRule(
                    name='mirostat_tau',
                    label=I18nObject(en_US="Mirostat Tau"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
                                          "A lower value will result in more focused and coherent text. (Default: 5.0)"),
                    default=5.0,
                    precision=1
                ),
                ParameterRule(
                    name='num_ctx',
                    label=I18nObject(en_US="Size of context window"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
                                          "(Default: 2048)"),
                    default=2048,
                    min=1
                ),
                ParameterRule(
                    name='num_gpu',
                    label=I18nObject(en_US="Num GPU"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="The number of layers to send to the GPU(s). "
                                          "On macOS it defaults to 1 to enable metal support, 0 to disable."),
                    default=1,
                    min=0,
                    max=1
                ),
                ParameterRule(
                    name='num_thread',
                    label=I18nObject(en_US="Num Thread"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the number of threads to use during computation. "
                                          "By default, Ollama will detect this for optimal performance. "
                                          "It is recommended to set this value to the number of physical CPU cores "
                                          "your system has (as opposed to the logical number of cores)."),
                    min=1,
                ),
                ParameterRule(
                    name='repeat_last_n',
                    label=I18nObject(en_US="Repeat last N"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
                                          "(Default: 64, 0 = disabled, -1 = num_ctx)"),
                    default=64,
                    min=-1
                ),
                ParameterRule(
                    name='tfs_z',
                    label=I18nObject(en_US="TFS Z"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                                          "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                                          "while a value of 1.0 disables this setting. (default: 1)"),
                    default=1,
                    precision=1
                ),
                ParameterRule(
                    name='seed',
                    label=I18nObject(en_US="Seed"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
                                          "a specific number will make the model generate the same text for "
                                          "the same prompt. (Default: 0)"),
                    default=0
                ),
                ParameterRule(
                    name='format',
                    label=I18nObject(en_US="Format"),
                    type=ParameterType.STRING,
                    help=I18nObject(en_US="the format to return a response in."
                                          " Currently the only accepted value is json."),
                    options=['json'],
                )
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                output=Decimal(credentials.get('output_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            ),
            **extras
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
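For reference, a standalone sketch (not from this commit) of the chat-mode request body that _generate assembles above; the base URL http://localhost:11434 and the model name 'llama2' are assumed placeholders.

import requests

data = {
    'model': 'llama2',  # hypothetical pulled model
    'stream': False,
    'options': {'temperature': 0.8, 'num_predict': 128},  # forwarded model_parameters
    'messages': [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Why is the sky blue?'},
    ],
}
response = requests.post(
    'http://localhost:11434/api/chat',  # assumed local Ollama endpoint
    headers={'Content-Type': 'application/json'},
    json=data,
    timeout=(10, 60),
)
response.raise_for_status()
print(response.json()['message']['content'])  # same field the response handler reads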
api/core/model_runtime/model_providers/ollama/ollama.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OllamaProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
api/core/model_runtime/model_providers/ollama/ollama.yaml (new file, 98 lines)
@@ -0,0 +1,98 @@
provider: ollama
label:
  en_US: Ollama
icon_large:
  en_US: icon_l_en.svg
icon_small:
  en_US: icon_s_en.svg
background: "#F9FAFB"
help:
  title:
    en_US: How to integrate with Ollama
    zh_Hans: 如何集成 Ollama
  url:
    en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: base_url
      label:
        zh_Hans: 基础 URL
        en_US: Base URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: 模型类型
        en_US: Completion mode
      type: select
      required: true
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '4096'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: '4096'
      type: text-input
      required: true
    - variable: vision_support
      label:
        zh_Hans: 是否支持 Vision
        en_US: Vision support
      show_on:
        - variable: __model_type
          value: llm
      default: 'false'
      type: radio
      required: false
      options:
        - value: 'true'
          label:
            en_US: "Yes"
            zh_Hans: 是
        - value: 'false'
          label:
            en_US: "No"
            zh_Hans: 否
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py (new file, 221 lines)
@@ -0,0 +1,221 @@
import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json

import numpy as np

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
    PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

logger = logging.getLogger(__name__)


class OllamaEmbeddingModel(TextEmbeddingModel):
    """
    Model class for an Ollama text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        # Prepare headers and payload for the request
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials.get('base_url')
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'

        endpoint_url = urljoin(endpoint_url, 'api/embeddings')

        # get model properties
        context_size = self._get_context_size(model, credentials)

        inputs = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)

        batched_embeddings = []

        for text in inputs:
            # Prepare the payload for the request
            payload = {
                'prompt': text,
                'model': model,
            }

            # Make the request to the Ollama embeddings API
            response = requests.post(
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Extract embeddings and used tokens from the response
            embeddings = response_data['embedding']
            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

            used_tokens += embedding_used_tokens
            batched_embeddings.append(embeddings)

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Approximate number of tokens for given messages using GPT2 tokenizer

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
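For reference, a standalone sketch (not from this commit) of the per-text embeddings request issued in the loop above; the base URL and the model name 'nomic-embed-text' are assumed placeholders.

import requests

resp = requests.post(
    'http://localhost:11434/api/embeddings',  # assumed local Ollama endpoint
    headers={'Content-Type': 'application/json'},
    json={'model': 'nomic-embed-text', 'prompt': 'hello world'},
    timeout=(10, 300),
)
resp.raise_for_status()
vector = resp.json()['embedding']  # list of floats, same key the code extracts
print(len(vector))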
@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     message=AssistantPromptMessage(content=""),
                     finish_reason="Non-JSON encountered."
                 )
+                break
 
             if not chunk_json or len(chunk_json['choices']) == 0:
                 continue
@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: Base URL, eg. https://api.openai.com/v1
-        en_US: Base URL, eg. https://api.openai.com/v1
+        zh_Hans: Base URL, e.g. https://api.openai.com/v1
+        en_US: Base URL, e.g. https://api.openai.com/v1
     - variable: mode
       show_on:
         - variable: __model_type
@@ -33,5 +33,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
+        zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
+        en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
@@ -34,8 +34,8 @@ model_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your Xinference, for example https://example.com/xxx
+        zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+        en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
     - variable: model_uid
       label:
         zh_Hans: 模型UID
@@ -121,6 +121,7 @@ class PromptTransform:
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config
@@ -343,7 +344,14 @@ class PromptTransform:
 
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
-            prompt_message = UserPromptMessage(content=prompt)
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_message = UserPromptMessage(content=prompt_message_contents)
+            else:
+                prompt_message = UserPromptMessage(content=prompt)
 
         return [prompt_message]
@@ -434,6 +442,7 @@ class PromptTransform:
                                           prompt_template_entity: PromptTemplateEntity,
                                           inputs: dict,
                                           query: str,
+                                          files: List[FileObj],
                                           context: Optional[str],
                                           memory: Optional[TokenBufferMemory],
                                           model_config: ModelConfigEntity) -> List[PromptMessage]:
@@ -461,7 +470,14 @@ class PromptTransform:
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-        prompt_messages.append(UserPromptMessage(content=prompt))
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
 
         return prompt_messages
@@ -62,5 +62,8 @@ COHERE_API_KEY=
 # Jina Credentials
 JINA_API_KEY=
 
+# Ollama Credentials
+OLLAMA_BASE_URL=
+
 # Mock Switch
 MOCK_SWITCH=false
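For reference, a sketch (not from this commit) of how the integration tests below consume these variables; the URL is a placeholder and the exact effect of MOCK_SWITCH is assumed from its name.

import os

os.environ.setdefault('OLLAMA_BASE_URL', 'http://localhost:11434')  # assumed local Ollama server
os.environ.setdefault('MOCK_SWITCH', 'false')  # mock toggle from the same env file (behaviour assumed)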
api/tests/integration_tests/model_runtime/ollama/test_llm.py (new file, 260 lines; diff suppressed because one or more lines are too long)
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel


def test_validate_credentials():
    model = OllamaEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistral:text',
            credentials={
                'base_url': 'http://localhost:21434',
                'mode': 'chat',
                'context_size': 4096,
            }
        )

    model.validate_credentials(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        }
    )


def test_invoke_model():
    model = OllamaEmbeddingModel()

    result = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = OllamaEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ]
    )

    assert num_tokens == 2