Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 19:59:50 +08:00)
Support for Vertex AI (#4586)
This commit is contained in:
parent 9ae72cdcf4
commit 296887754f
@ -2,6 +2,7 @@
 - anthropic
 - azure_openai
 - google
+ - vertex_ai
 - nvidia
 - cohere
 - bedrock
Binary file not shown (new image, 18 KiB).
@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24px" height="24px"><path d="M20,13.89A.77.77,0,0,0,19,13.73l-7,5.14v.22a.72.72,0,1,1,0,1.43v0a.74.74,0,0,0,.45-.15l7.41-5.47A.76.76,0,0,0,20,13.89Z" style="fill:#669df6"/><path d="M12,20.52a.72.72,0,0,1,0-1.43h0v-.22L5,13.73a.76.76,0,0,0-1,.16.74.74,0,0,0,.16,1l7.41,5.47a.73.73,0,0,0,.44.15v0Z" style="fill:#aecbfa"/><path d="M12,18.34a1.47,1.47,0,1,0,1.47,1.47A1.47,1.47,0,0,0,12,18.34Zm0,2.18a.72.72,0,1,1,.72-.71A.71.71,0,0,1,12,20.52Z" style="fill:#4285f4"/><path d="M6,6.11a.76.76,0,0,1-.75-.75V3.48a.76.76,0,1,1,1.51,0V5.36A.76.76,0,0,1,6,6.11Z" style="fill:#aecbfa"/><circle cx="5.98" cy="12" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="9.79" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="7.57" r="0.76" style="fill:#aecbfa"/><path d="M18,8.31a.76.76,0,0,1-.75-.76V5.67a.75.75,0,1,1,1.5,0V7.55A.75.75,0,0,1,18,8.31Z" style="fill:#4285f4"/><circle cx="18.02" cy="12.01" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="9.76" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="3.48" r="0.76" style="fill:#4285f4"/><path d="M12,15a.76.76,0,0,1-.75-.75V12.34a.76.76,0,0,1,1.51,0v1.89A.76.76,0,0,1,12,15Z" style="fill:#669df6"/><circle cx="12" cy="16.45" r="0.76" style="fill:#669df6"/><circle cx="12" cy="10.14" r="0.76" style="fill:#669df6"/><circle cx="12" cy="7.92" r="0.76" style="fill:#669df6"/><path d="M15,10.54a.76.76,0,0,1-.75-.75V7.91a.76.76,0,1,1,1.51,0V9.79A.76.76,0,0,1,15,10.54Z" style="fill:#4285f4"/><circle cx="15.01" cy="5.69" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="14.19" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="11.97" r="0.76" style="fill:#4285f4"/><circle cx="8.99" cy="14.19" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="7.92" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="5.69" r="0.76" style="fill:#aecbfa"/><path d="M9,12.73A.76.76,0,0,1,8.24,12V10.1a.75.75,0,1,1,1.5,0V12A.75.75,0,0,1,9,12.73Z" style="fill:#aecbfa"/></svg>
(new SVG icon, 1.9 KiB)
api/core/model_runtime/model_providers/vertex_ai/_common.py (new file, 15 lines)
@ -0,0 +1,15 @@
from core.model_runtime.errors.invoke import InvokeError


class _CommonVertexAi:
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        pass
@ -0,0 +1,38 @@
model: gemini-1.0-pro-vision-001
label:
  en_US: Gemini 1.0 Pro Vision
model_type: llm
features:
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,38 @@
model: gemini-1.0-pro-002
label:
  en_US: Gemini 1.0 Pro
model_type: llm
features:
  - agent-thought
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32760
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,38 @@
model: gemini-1.5-flash-preview-0514
label:
  en_US: Gemini 1.5 Flash
model_type: llm
features:
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,39 @@
model: gemini-1.5-pro-preview-0514
label:
  en_US: Gemini 1.5 Pro
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
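Taken together, the parameter_rules above map onto the model_parameters dict a caller passes at invocation time; a minimal illustrative sketch (the values are placeholders, not defaults taken from these files):

# Illustrative only: a model_parameters dict consistent with the rules declared above.
model_parameters = {
    "temperature": 0.7,         # use_template: temperature
    "top_p": 0.95,              # use_template: top_p
    "top_k": 40,                # optional int
    "presence_penalty": 0.0,    # use_template: presence_penalty
    "frequency_penalty": 0.0,   # use_template: frequency_penalty
    "max_output_tokens": 1024,  # use_template: max_tokens; 1..8192 for the Gemini 1.5 models
}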
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py (new file, 438 lines)
@ -0,0 +1,438 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import google.api_core.exceptions as exceptions
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""


class VertexAiLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return self._get_num_tokens_by_gpt2(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        Format a list of messages into a full prompt for the Google model

        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        return text.rstrip()

    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
        """
        Convert tool messages to glm tools

        :param tools: tool messages
        :return: glm tools
        """
        return glm.Tool(
            function_declarations=[
                glm.FunctionDeclaration(
                    name=tool.name,
                    parameters=glm.Schema(
                        type=glm.Type.OBJECT,
                        properties={
                            key: {
                                'type_': value.get('type', 'string').upper(),
                                'description': value.get('description', ''),
                                'enum': value.get('enum', [])
                            } for key, value in tool.parameters.get('properties', {}).items()
                        },
                        required=tool.parameters.get('required', [])
                    ),
                ) for tool in tools
            ]
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None
                  ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: credentials kwargs
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        config_kwargs = model_parameters.copy()
        # map the legacy `max_tokens_to_sample` key without overwriting an explicit `max_output_tokens`
        if 'max_tokens_to_sample' in config_kwargs:
            config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample')

        if stop:
            config_kwargs["stop_sequences"] = stop

        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

        history = []
        system_instruction = GEMINI_BLOCK_MODE_PROMPT
        # hack for gemini-pro-vision, which currently does not support multi-turn chat
        if model == "gemini-1.0-pro-vision-001":
            last_msg = prompt_messages[-1]
            content = self._format_message_to_glm_content(last_msg)
            history.append(content)
        else:
            for msg in prompt_messages:
                if isinstance(msg, SystemPromptMessage):
                    system_instruction = msg.content
                else:
                    content = self._format_message_to_glm_content(msg)
                    if history and history[-1].role == content.role:
                        history[-1].parts.extend(content.parts)
                    else:
                        history.append(content)

        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

        google_model = glm.GenerativeModel(
            model_name=model,
            system_instruction=system_instruction
        )

        response = google_model.generate_content(
            contents=history,
            generation_config=glm.GenerationConfig(
                **config_kwargs
            ),
            stream=stream,
            safety_settings=safety_settings,
            tools=self._convert_tools_to_glm_tool(tools) if tools else None
        )

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.candidates[0].content.parts[0].text
        )

        # calculate num tokens
        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        index = -1
        for chunk in response:
            for part in chunk.candidates[0].content.parts:
                assistant_prompt_message = AssistantPromptMessage(
                    content=''
                )

                if part.text:
                    assistant_prompt_message.content += part.text

                if part.function_call:
                    assistant_prompt_message.tool_calls = [
                        AssistantPromptMessage.ToolCall(
                            id=part.function_call.name,
                            type='function',
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=part.function_call.name,
                                arguments=json.dumps({
                                    key: value
                                    for key, value in part.function_call.args.items()
                                })
                            )
                        )
                    ]

                index += 1

                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
                    # transform assistant message to prompt message
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message
                        )
                    )
                else:
                    # calculate num tokens
                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

                    # transform usage
                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message,
                            finish_reason=chunk.candidates[0].finish_reason,
                            usage=usage
                        )
                    )

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nuser:"
        ai_prompt = "\n\nmodel:"

        content = message.content
        if isinstance(content, list):
            content = "".join(
                c.data for c in content if c.type != PromptMessageContentType.IMAGE
            )

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
        """
        Format a single message into glm.Content for Google API

        :param message: one PromptMessage
        :return: glm Content representation of message
        """
        if isinstance(message, UserPromptMessage):
            glm_content = glm.Content(role="user", parts=[])

            if isinstance(message.content, str):
                glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
            else:
                parts = []
                for c in message.content:
                    if c.type == PromptMessageContentType.TEXT:
                        parts.append(glm.Part.from_text(c.data))
                    else:
                        # image content arrives as a data URL: "data:<mime_type>;base64,<payload>"
                        metadata, data = c.data.split(',', 1)
                        mime_type = metadata.split(';', 1)[0].split(':')[1]
                        blob = {"inline_data": {"mime_type": mime_type, "data": data}}
                        parts.append(blob)

                # `parts` is already a list of parts; don't wrap it in another list
                glm_content = glm.Content(role="user", parts=parts)
            return glm_content
        elif isinstance(message, AssistantPromptMessage):
            if message.content:
                glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
            if message.tool_calls:
                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
                    name=message.tool_calls[0].function.name,
                    args=json.loads(message.tool_calls[0].function.arguments),
                ))])
            return glm_content
        elif isinstance(message, ToolPromptMessage):
            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
                name=message.name,
                response={
                    "response": message.content
                }
            ))])
            return glm_content
        else:
            raise ValueError(f"Got unknown type {message}")

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                exceptions.RetryError
            ],
            InvokeServerUnavailableError: [
                exceptions.ServiceUnavailable,
                exceptions.InternalServerError,
                exceptions.BadGateway,
                exceptions.GatewayTimeout,
                exceptions.DeadlineExceeded
            ],
            InvokeRateLimitError: [
                exceptions.ResourceExhausted,
                exceptions.TooManyRequests
            ],
            InvokeAuthorizationError: [
                exceptions.Unauthenticated,
                exceptions.PermissionDenied,
                exceptions.Forbidden
            ],
            InvokeBadRequestError: [
                exceptions.BadRequest,
                exceptions.InvalidArgument,
                exceptions.FailedPrecondition,
                exceptions.OutOfRange,
                exceptions.NotFound,
                exceptions.MethodNotAllowed,
                exceptions.Conflict,
                exceptions.AlreadyExists,
                exceptions.Aborted,
                exceptions.LengthRequired,
                exceptions.PreconditionFailed,
                exceptions.RequestRangeNotSatisfiable,
                exceptions.Cancelled,
            ]
        }
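A small sketch of how _format_message_to_glm_content above turns an image supplied as a data URL into the inline_data blob handed to the Vertex SDK; the payload below is a placeholder, only the split logic mirrors the code:

# Placeholder data URL; the parsing follows _format_message_to_glm_content.
data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."
metadata, data = data_url.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]  # -> "image/png"
blob = {"inline_data": {"mime_type": mime_type, "data": data}}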
@ -0,0 +1,8 @@
model: text-embedding-004
model_type: text-embedding
model_properties:
  context_size: 2048
pricing:
  input: '0.00013'
  unit: '0.001'
  currency: USD
@ -0,0 +1,8 @@
model: text-multilingual-embedding-002
model_type: text-embedding
model_properties:
  context_size: 2048
pricing:
  input: '0.00013'
  unit: '0.001'
  currency: USD
@ -0,0 +1,193 @@
import base64
import json
import time
from decimal import Decimal
from typing import Optional

import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
    PriceConfig,
    PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi


class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
    """
    Model class for Vertex AI text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings result
        """
        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

        client = VertexTextEmbeddingModel.from_pretrained(model)

        embeddings_batch, embedding_used_tokens = self._embedding_invoke(
            client=client,
            texts=texts
        )

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=embedding_used_tokens
        )

        return TextEmbeddingResult(
            embeddings=embeddings_batch,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        if len(texts) == 0:
            return 0

        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            enc = tiktoken.get_encoding("cl100k_base")

        total_num_tokens = 0
        for text in texts:
            # calculate the number of tokens in the encoded text
            tokenized_text = enc.encode(text)
            total_num_tokens += len(tokenized_text)

        return total_num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
            project_id = credentials["vertex_project_id"]
            location = credentials["vertex_location"]
            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)

            client = VertexTextEmbeddingModel.from_pretrained(model)

            # call embedding model (note: _embedding_invoke takes no `model` argument)
            self._embedding_invoke(
                client=client,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param client: model client
        :param texts: texts to embed
        :return: embeddings and used tokens
        """
        response = client.get_embeddings(texts)

        embeddings = []
        token_usage = 0

        for i in range(len(response)):
            embeddings.append(response[i].values)
            token_usage += int(response[i].statistics.token_count)

        return embeddings, token_usage

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity
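For context, _embedding_invoke above relies on the Vertex SDK returning one embedding object per input text, each exposing values and statistics.token_count; a minimal sketch, assuming aiplatform.init has already been called with valid credentials as in _invoke:

from vertexai.language_models import TextEmbeddingModel

# Assumes aiplatform.init(credentials=..., project=..., location=...) has already run.
client = TextEmbeddingModel.from_pretrained("text-embedding-004")
response = client.get_embeddings(["hello world"])
vector = response[0].values                            # list[float]
used_tokens = int(response[0].statistics.token_count)  # tokens counted for this text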
@ -0,0 +1,31 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class VertexAiProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the `gemini-1.0-pro-002` model for validation.
            model_instance.validate_credentials(
                model='gemini-1.0-pro-002',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@ -0,0 +1,43 @@
provider: vertex_ai
label:
  en_US: Vertex AI | Google Cloud Platform
description:
  en_US: Vertex AI in Google Cloud Platform.
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.png
background: "#FCFDFF"
help:
  title:
    en_US: Get your Access Details from Google
  url:
    en_US: https://cloud.google.com/vertex-ai/
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: vertex_project_id
      label:
        en_US: Project ID
      type: text-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Project ID
    - variable: vertex_location
      label:
        en_US: Location
      type: text-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Location
    - variable: vertex_service_account_key
      label:
        en_US: Service Account Key
      type: secret-input
      required: true
      placeholder:
        en_US: Enter your Google Cloud Service Account Key in base64 format
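The vertex_service_account_key credential is expected as a base64-encoded service-account JSON key (llm.py and text_embedding.py decode it with base64.b64decode before json.loads); a minimal encoding sketch, with the key file path as a placeholder:

import base64
import json

# Placeholder path to a downloaded Google Cloud service-account key file.
with open("service-account.json", "rb") as f:
    raw = f.read()

json.loads(raw)  # sanity check: the file must be valid JSON before encoding
vertex_service_account_key = base64.b64encode(raw).decode("utf-8")
print(vertex_service_account_key)  # paste this value into the Service Account Key field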
@ -84,3 +84,4 @@ pgvecto-rs==0.1.4
 firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5
+google-cloud-aiplatform==1.49.0