Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
Add OCI (Oracle Cloud Infrastructure) Generative AI Service as a Model Provider (#7775)
Co-authored-by: Walter Jin <jinshuhaicc@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: walter from vm <walter.jin@oracle.com>
This commit is contained in:
parent e0d3cd91c6
commit 89aede80cc
@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
@ -0,0 +1,52 @@
model: cohere.command-r-16k
label:
  en_US: cohere.command-r-16k v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.004'
  output: '0.004'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,52 @@
model: cohere.command-r-plus
label:
  en_US: cohere.command-r-plus v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.0219'
  output: '0.0219'
  unit: '0.0001'
  currency: USD
461  api/core/model_runtime/model_providers/oci/llm/llm.py  Normal file
@ -0,0 +1,461 @@
import base64
import copy
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.command-r-plus",
        "servingType": "ON_DEMAND"
    },
    "chatRequest": {
        "apiFormat": "COHERE",
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75
    }
}
oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCILargeLanguageModel(LargeLanguageModel):
    # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
    _supported_models = {
        "meta.llama-3-70b-instruct": {
            "system": True,
            "multimodal": False,
            "tool_call": False,
            "stream_tool_call": False,
        },
        "cohere.command-r-16k": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
        "cohere.command-r-plus": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
    }

    def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["stream_tool_call"] if stream else feature["tool_call"]

    def _is_multimodal_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["multimodal"]

    def _is_system_prompt_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["system"]

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens (approximated with the GPT-2 tokenizer)
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return self._get_num_tokens_by_gpt2(prompt)

    def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                           tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of characters for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of characters
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return len(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        return text.rstrip()

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"maxTokens": 5})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None
                  ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: credentials kwargs
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError(
                    "oci_config_content should be base64.b64encode("
                    "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")

        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)

        # build the chat request
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model

        chathistory = []
        system_prompts = []
        request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
        request_args["chatRequest"].update(model_parameters)
        frequency_penalty = model_parameters.get("frequencyPenalty", 0)
        presence_penalty = model_parameters.get("presencePenalty", 0)
        if frequency_penalty > 0 and presence_penalty > 0:
            raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")

        # tool calling is not implemented yet; reject requests the model cannot serve
        valid_value = self._is_tool_call_supported(model, stream)
        if tools is not None and len(tools) > 0:
            if not valid_value:
                raise InvokeBadRequestError("Does not support function calling")

        if model.startswith("cohere"):
            for message in prompt_messages[:-1]:
                text = ""
                if isinstance(message.content, str):
                    text = message.content
                if isinstance(message, UserPromptMessage):
                    chathistory.append({"role": "USER", "message": text})
                else:
                    chathistory.append({"role": "CHATBOT", "message": text})
                if isinstance(message, SystemPromptMessage):
                    if isinstance(message.content, str):
                        system_prompts.append(message.content)
            args = {"apiFormat": "COHERE",
                    "preambleOverride": ' '.join(system_prompts),
                    "message": prompt_messages[-1].content,
                    "chatHistory": chathistory, }
            request_args["chatRequest"].update(args)
        elif model.startswith("meta"):
            meta_messages = []
            for message in prompt_messages:
                text = message.content
                meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
            args = {"apiFormat": "GENERIC",
                    "messages": meta_messages,
                    "numGenerations": 1,
                    "topK": -1}
            request_args["chatRequest"].update(args)

        if stream:
            request_args["chatRequest"]["isStream"] = True
        response = client.chat(request_args)

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.data.chat_response.text
        )

        # calculate num tokens
        prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
        completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        index = -1
        events = response.data.events()
        for stream in events:
            chunk = json.loads(stream.data)
            # example chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}

            if "finishReason" not in chunk:
                assistant_prompt_message = AssistantPromptMessage(
                    content=''
                )
                if model.startswith("cohere"):
                    if chunk["text"]:
                        assistant_prompt_message.content += chunk["text"]
                elif model.startswith("meta"):
                    assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
                index += 1
                # transform assistant message to prompt message
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message
                    )
                )
            else:
                # calculate num tokens
                prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
                completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                        finish_reason=str(chunk["finishReason"]),
                        usage=usage
                    )
                )

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nuser:"
        ai_prompt = "\n\nmodel:"

        content = message.content
        if isinstance(content, list):
            content = "".join(
                c.data for c in content if c.type != PromptMessageContentType.IMAGE
            )

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: []
        }
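For reference, the dictionary below is a sketch of the payload that _generate assembles and passes to client.chat() for a Cohere-family model; the compartment OCID, preamble, and messages are placeholders, not values from this commit:

example_request = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder OCID
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
        # preambleOverride collects any SystemPromptMessage contents
        "preambleOverride": "You are a helpful assistant.",
        # the last prompt message becomes the Cohere "message" field
        "message": "What can the OCI Generative AI service do?",
        # earlier messages are mapped to USER / CHATBOT chat history entries
        "chatHistory": [
            {"role": "USER", "message": "Hello"},
            {"role": "CHATBOT", "message": "Hi, how can I help?"},
        ],
    },
}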
@ -0,0 +1,51 @@
model: meta.llama-3-70b-instruct
label:
  zh_Hans: meta.llama-3-70b-instruct
  en_US: meta.llama-3-70b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 2.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 8000
pricing:
  input: '0.015'
  output: '0.015'
  unit: '0.0001'
  currency: USD
34  api/core/model_runtime/model_providers/oci/oci.py  Normal file
@ -0,0 +1,34 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OCIGENAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # use the `cohere.command-r-plus` model for validation
            model_instance.validate_credentials(
                model='cohere.command-r-plus',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
42  api/core/model_runtime/model_providers/oci/oci.yaml  Normal file
@ -0,0 +1,42 @@
provider: oci
label:
  en_US: OCIGenerativeAI
description:
  en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
  zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from OCI
    zh_Hans: 从 OCI 获取 API Key
  url:
    en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
supported_model_types:
  - llm
  - text-embedding
  #- rerank
configurate_methods:
  - predefined-model
  #- customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: oci_config_content
      label:
        en_US: oci api key config file's content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
        en_US: Enter your oci api key config file's content (base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
    - variable: oci_key_content
      label:
        en_US: oci api key file's content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
        en_US: Enter your oci api key file's content (base64.b64encode("pem file content".encode('utf-8')))
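The two credential fields above are plain base64 strings rather than uploaded files. A minimal sketch of how a user could produce them so that they decode the way llm.py and text_embedding.py expect; every OCID, the fingerprint, the region, and the key file path below are placeholders:

import base64

user_ocid = "ocid1.user.oc1..example"            # placeholder
fingerprint = "aa:bb:cc:dd:ee:ff:00:11:22:33:44:55:66:77:88:99"  # placeholder
tenancy_ocid = "ocid1.tenancy.oc1..example"      # placeholder
region = "us-chicago-1"                          # placeholder region
compartment_ocid = "ocid1.compartment.oc1..example"  # placeholder

# oci_config_content: the five values joined with "/" and base64-encoded
config_str = "/".join([user_ocid, fingerprint, tenancy_ocid, region, compartment_ocid])
oci_config_content = base64.b64encode(config_str.encode("utf-8")).decode("utf-8")

# oci_key_content: the API signing key (PEM) base64-encoded
with open("oci_api_key.pem", "rb") as f:         # placeholder path to your key file
    oci_key_content = base64.b64encode(f.read()).decode("utf-8")

print(oci_config_content)
print(oci_key_content)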
@ -0,0 +1,5 @@
- cohere.embed-english-light-v2.0
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0
@ -0,0 +1,9 @@
model: cohere.embed-english-light-v2.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional

import numpy as np
import oci

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.embed-english-light-v3.0",
        "servingType": "ON_DEMAND"
    },
    "truncate": "NONE",
    "inputs": [""]
}
oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCITextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Cohere text embedding model on OCI.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # if num tokens is larger than context length, only use the start
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            # call embedding model
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=inputs[i: i + max_chunks]
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of tokens (approximated with the GPT-2 tokenizer)
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of characters for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of characters
        """
        characters = 0
        for text in texts:
            characters += len(text)
        return characters

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # call embedding model
            self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings and used tokens
        """
        # initialize client
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError(
                    "oci_config_content should be base64.b64encode("
                    "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")

        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)

        # call embedding model
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model
        request_args["inputs"] = texts
        response = client.embed_text(request_args)
        return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }
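A small sketch of the batching behaviour above, with assumed inputs: max_chunks comes from the embedding model YAMLs (48 for the cohere.embed-* models), so a list of 100 texts is embedded in three embed_text calls:

# Illustrative values only; the real inputs come from the caller of _invoke.
inputs = [f"text {i}" for i in range(100)]
max_chunks = 48
batches = [inputs[i:i + max_chunks] for i in range(0, len(inputs), max_chunks)]
print([len(b) for b in batches])  # [48, 48, 4] -> three calls to embed_text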
915  api/poetry.lock  generated
File diff suppressed because it is too large
@ -190,6 +190,7 @@ zhipuai = "1.0.7"
azure-ai-ml = "^1.19.0"
azure-ai-inference = "^1.0.0b3"
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
oci = "^2.133.0"
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
130  api/tests/integration_tests/model_runtime/oci/test_llm.py  Normal file
@ -0,0 +1,130 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel


def test_validate_credentials():
    model = OCILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="cohere.command-r-plus",
            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
        )

    model.validate_credentials(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
    )


def test_invoke_model():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_model():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="meta.llama-3-70b-instruct",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_model_with_function():
    model = OCILargeLanguageModel()

    response = model.invoke(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=False,
        user="abc-123",
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_get_num_tokens():
    model = OCILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="cohere.command-r-plus",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 18
@ -0,0 +1,20 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider


def test_validate_provider_credentials():
    provider = OCIGENAIProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        }
    )
@ -0,0 +1,58 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel


def test_validate_credentials():
    model = OCITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="cohere.embed-multilingual-v3.0",
            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
        )

    model.validate_credentials(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
    )


def test_invoke_model():
    model = OCITextEmbeddingModel()

    result = model.invoke(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    # assert result.usage.total_tokens == 811


def test_get_num_tokens():
    model = OCITextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="cohere.embed-multilingual-v3.0",
        credentials={
            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2