Support for Vertex AI (#4586)

Patryk Garstecki 2024-05-24 06:01:40 +02:00 committed by GitHub
parent 9ae72cdcf4
commit 296887754f
18 changed files with 892 additions and 0 deletions

View File

@@ -2,6 +2,7 @@
- anthropic
- azure_openai
- google
- vertex_ai
- nvidia
- cohere
- bedrock

Binary file not shown (new image, 18 KiB)

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24px" height="24px"><path d="M20,13.89A.77.77,0,0,0,19,13.73l-7,5.14v.22a.72.72,0,1,1,0,1.43v0a.74.74,0,0,0,.45-.15l7.41-5.47A.76.76,0,0,0,20,13.89Z" style="fill:#669df6"/><path d="M12,20.52a.72.72,0,0,1,0-1.43h0v-.22L5,13.73a.76.76,0,0,0-1,.16.74.74,0,0,0,.16,1l7.41,5.47a.73.73,0,0,0,.44.15v0Z" style="fill:#aecbfa"/><path d="M12,18.34a1.47,1.47,0,1,0,1.47,1.47A1.47,1.47,0,0,0,12,18.34Zm0,2.18a.72.72,0,1,1,.72-.71A.71.71,0,0,1,12,20.52Z" style="fill:#4285f4"/><path d="M6,6.11a.76.76,0,0,1-.75-.75V3.48a.76.76,0,1,1,1.51,0V5.36A.76.76,0,0,1,6,6.11Z" style="fill:#aecbfa"/><circle cx="5.98" cy="12" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="9.79" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="7.57" r="0.76" style="fill:#aecbfa"/><path d="M18,8.31a.76.76,0,0,1-.75-.76V5.67a.75.75,0,1,1,1.5,0V7.55A.75.75,0,0,1,18,8.31Z" style="fill:#4285f4"/><circle cx="18.02" cy="12.01" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="9.76" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="3.48" r="0.76" style="fill:#4285f4"/><path d="M12,15a.76.76,0,0,1-.75-.75V12.34a.76.76,0,0,1,1.51,0v1.89A.76.76,0,0,1,12,15Z" style="fill:#669df6"/><circle cx="12" cy="16.45" r="0.76" style="fill:#669df6"/><circle cx="12" cy="10.14" r="0.76" style="fill:#669df6"/><circle cx="12" cy="7.92" r="0.76" style="fill:#669df6"/><path d="M15,10.54a.76.76,0,0,1-.75-.75V7.91a.76.76,0,1,1,1.51,0V9.79A.76.76,0,0,1,15,10.54Z" style="fill:#4285f4"/><circle cx="15.01" cy="5.69" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="14.19" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="11.97" r="0.76" style="fill:#4285f4"/><circle cx="8.99" cy="14.19" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="7.92" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="5.69" r="0.76" style="fill:#aecbfa"/><path d="M9,12.73A.76.76,0,0,1,8.24,12V10.1a.75.75,0,1,1,1.5,0V12A.75.75,0,0,1,9,12.73Z" style="fill:#aecbfa"/></svg>

(new SVG icon, 1.9 KiB)

View File

@@ -0,0 +1,15 @@
from core.model_runtime.errors.invoke import InvokeError
class _CommonVertexAi:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
pass
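For context, a short illustrative sketch (not part of this commit) of how a mapping like this is typically consumed: the caller catches an SDK exception, walks the mapping, and re-raises it as the matching unified error. The function name and fallback behaviour below are assumptions, not an API defined in this repository.

from core.model_runtime.errors.invoke import InvokeError

def translate_error(error: Exception,
                    mapping: dict[type[InvokeError], list[type[Exception]]]) -> InvokeError:
    # Illustrative only: wrap the raw SDK exception in the first unified error
    # type whose exception list matches it.
    for invoke_error_type, sdk_error_types in mapping.items():
        if isinstance(error, tuple(sdk_error_types)):
            return invoke_error_type(str(error))
    # No match: fall back to the generic unified error.
    return InvokeError(str(error))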

View File

@@ -0,0 +1,38 @@
model: gemini-1.0-pro-vision-001
label:
en_US: Gemini 1.0 Pro Vision
model_type: llm
features:
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 16384
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD
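Purely as an illustration (not from this commit), the parameter_rules above surface in the runtime as a model_parameters dict at invocation time; a hypothetical example with arbitrary values:

model_parameters = {
    "temperature": 0.2,         # use_template: temperature
    "top_p": 0.95,              # use_template: top_p
    "top_k": 40,                # provider-specific rule defined above
    "max_output_tokens": 1024,  # templated from max_tokens, capped at 2048 for this model
}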

View File

@@ -0,0 +1,38 @@
model: gemini-1.0-pro-002
label:
en_US: Gemini 1.0 Pro
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32760
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,38 @@
model: gemini-1.5-flash-preview-0514
label:
en_US: Gemini 1.5 Flash
model_type: llm
features:
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,39 @@
model: gemini-1.5-pro-preview-0514
label:
en_US: Gemini 1.5 Pro
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,438 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
import google.api_core.exceptions as exceptions
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class VertexAiLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: number of tokens
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
return glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
name=tool.name,
parameters=glm.Schema(
type=glm.Type.OBJECT,
properties={
key: {
'type_': value.get('type', 'string').upper(),
'description': value.get('description', ''),
'enum': value.get('enum', [])
} for key, value in tool.parameters.get('properties', {}).items()
},
required=tool.parameters.get('required', [])
),
) for tool in tools
]
)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
# map max_tokens_to_sample (used by validate_credentials) to max_output_tokens without clobbering an explicit value
config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', config_kwargs.get('max_output_tokens'))
if stop:
config_kwargs["stop_sequences"] = stop
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
history = []
system_instruction = GEMINI_BLOCK_MODE_PROMPT
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-1.0-pro-vision-001":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages:
if isinstance(msg, SystemPromptMessage):
system_instruction = msg.content
else:
content = self._format_message_to_glm_content(msg)
if history and history[-1].role == content.role:
history[-1].parts.extend(content.parts)
else:
history.append(content)
safety_settings = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
google_model = glm.GenerativeModel(
model_name=model,
system_instruction=system_instruction
)
response = google_model.generate_content(
contents=history,
generation_config=glm.GenerationConfig(
**config_kwargs
),
stream=stream,
safety_settings=safety_settings,
tools=self._convert_tools_to_glm_tool(tools) if tools else None
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.candidates[0].content.parts[0].text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
index = -1
for chunk in response:
for part in chunk.candidates[0].content.parts:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
if part.text:
assistant_prompt_message.content += part.text
if part.function_call:
assistant_prompt_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=part.function_call.name,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=part.function_call.name,
arguments=json.dumps({
key: value
for key, value in part.function_call.args.items()
})
)
)
]
index += 1
if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=chunk.candidates[0].finish_reason,
usage=usage
)
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nuser:"
ai_prompt = "\n\nmodel:"
content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
if isinstance(message, UserPromptMessage):
glm_content = glm.Content(role="user", parts=[])
if isinstance(message.content, str):
glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
else:
parts = []
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
parts.append(glm.Part.from_text(c.data))
else:
metadata, data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
parts.append(blob)
glm_content = glm.Content(role="user", parts=[parts])
return glm_content
elif isinstance(message, AssistantPromptMessage):
if message.content:
glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
if message.tool_calls:
glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
))])
return glm_content
elif isinstance(message, ToolPromptMessage):
glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
name=message.name,
response={
"response": message.content
}
))])
return glm_content
else:
raise ValueError(f"Got unknown type {message}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
exceptions.RetryError
],
InvokeServerUnavailableError: [
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.BadGateway,
exceptions.GatewayTimeout,
exceptions.DeadlineExceeded
],
InvokeRateLimitError: [
exceptions.ResourceExhausted,
exceptions.TooManyRequests
],
InvokeAuthorizationError: [
exceptions.Unauthenticated,
exceptions.PermissionDenied,
exceptions.Forbidden
],
InvokeBadRequestError: [
exceptions.BadRequest,
exceptions.InvalidArgument,
exceptions.FailedPrecondition,
exceptions.OutOfRange,
exceptions.NotFound,
exceptions.MethodNotAllowed,
exceptions.Conflict,
exceptions.AlreadyExists,
exceptions.Aborted,
exceptions.LengthRequired,
exceptions.PreconditionFailed,
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
]
}
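For reference, a minimal standalone sketch of the Vertex AI SDK calls this class wraps: decoding the base64 service-account key, initializing the client, and streaming a generation. The project ID, location, model name, and key value are placeholders; the calls mirror those in `_generate` and `_handle_generate_stream_response` above.

import base64
import json

import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account

encoded_key = "..."  # placeholder: base64-encoded service-account JSON, as stored in the provider credentials
service_account_info = json.loads(base64.b64decode(encoded_key))
credentials = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=credentials, project="my-project", location="us-central1")

model = glm.GenerativeModel(model_name="gemini-1.0-pro-002")
response = model.generate_content(
    contents=[glm.Content(role="user", parts=[glm.Part.from_text("ping")])],
    generation_config=glm.GenerationConfig(temperature=0.2, max_output_tokens=64),
    stream=True,
)
for chunk in response:
    # Each streamed chunk carries candidates whose content parts hold the text delta,
    # which _handle_generate_stream_response turns into LLMResultChunk objects.
    print(chunk.candidates[0].content.parts[0].text, end="")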

View File

@@ -0,0 +1,8 @@
model: text-embedding-004
model_type: text-embedding
model_properties:
context_size: 2048
pricing:
input: '0.00013'
unit: '0.001'
currency: USD

View File

@@ -0,0 +1,8 @@
model: text-multilingual-embedding-002
model_type: text-embedding
model_properties:
context_size: 2048
pricing:
input: '0.00013'
unit: '0.001'
currency: USD

View File

@@ -0,0 +1,193 @@
import base64
import json
import time
from decimal import Decimal
from typing import Optional
import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
PriceConfig,
PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
"""
Model class for Vertex AI text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return: embeddings result
"""
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
client = VertexTextEmbeddingModel.from_pretrained(model)
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
client=client,
texts=texts
)
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=embedding_used_tokens
)
return TextEmbeddingResult(
embeddings=embeddings_batch,
usage=usage,
model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
try:
enc = tiktoken.encoding_for_model(model)
except KeyError:
enc = tiktoken.get_encoding("cl100k_base")
total_num_tokens = 0
for text in texts:
# calculate the number of tokens in the encoded text
tokenized_text = enc.encode(text)
total_num_tokens += len(tokenized_text)
return total_num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
client = VertexTextEmbeddingModel.from_pretrained(model)
# call embedding model
self._embedding_invoke(
client=client,
texts=['ping']
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
"""
Invoke embedding model
:param model: model name
:param client: model client
:param texts: texts to embed
:return: embeddings and used tokens
"""
response = client.get_embeddings(texts)
embeddings = []
token_usage = 0
for i in range(len(response)):
embeddings.append(response[i].values)
token_usage += int(response[i].statistics.token_count)
return embeddings, token_usage
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)
return entity
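For reference, a minimal sketch of the embedding call this class wraps, assuming aiplatform.init has already been run with valid service-account credentials (as in _invoke above); the input text is arbitrary.

from vertexai.language_models import TextEmbeddingModel

client = TextEmbeddingModel.from_pretrained("text-embedding-004")
results = client.get_embeddings(["ping"])
# Each result exposes the embedding vector and per-text token statistics,
# which _embedding_invoke aggregates into the usage figures above.
vector = results[0].values
used_tokens = int(results[0].statistics.token_count)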

View File

@@ -0,0 +1,31 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class VertexAiProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use the `gemini-1.0-pro-002` model for validation.
model_instance.validate_credentials(
model='gemini-1.0-pro-002',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@@ -0,0 +1,43 @@
provider: vertex_ai
label:
en_US: Vertex AI | Google Cloud Platform
description:
en_US: Vertex AI in Google Cloud Platform.
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#FCFDFF"
help:
title:
en_US: Get your Access Details from Google
url:
en_US: https://cloud.google.com/vertex-ai/
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: vertex_project_id
label:
en_US: Project ID
type: text-input
required: true
placeholder:
en_US: Enter your Google Cloud Project ID
- variable: vertex_location
label:
en_US: Location
type: text-input
required: true
placeholder:
en_US: Enter your Google Cloud Location
- variable: vertex_service_account_key
label:
en_US: Service Account Key
type: secret-input
required: true
placeholder:
en_US: Enter your Google Cloud Service Account Key in base64 format
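Because the provider code decodes this field with base64.b64decode before building service-account credentials, the key has to be supplied base64-encoded. A small sketch of one way to prepare it; the file path is a placeholder for a service-account JSON key downloaded from Google Cloud IAM:

import base64

# Placeholder path: a service-account JSON key downloaded from Google Cloud IAM
with open("service-account.json", "rb") as f:
    encoded_key = base64.b64encode(f.read()).decode("utf-8")
print(encoded_key)  # paste this value into the Service Account Key field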

View File

@@ -84,3 +84,4 @@ pgvecto-rs==0.1.4
firecrawl-py==0.0.5
oss2==2.18.5
pgvector==0.2.5
google-cloud-aiplatform==1.49.0