Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)

refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166)

Parent: af92f19291
Commit: ed37439ef7
api/core/agent/base_agent_runner.py

@@ -1,6 +1,7 @@
 import json
 import logging
 import uuid
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Optional, Union, cast
 
@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
 
 
 class BaseAgentRunner(AppRunner):
-    def __init__(self, tenant_id: str,
-                 application_generate_entity: AgentChatAppGenerateEntity,
-                 conversation: Conversation,
-                 app_config: AgentChatAppConfig,
-                 model_config: ModelConfigWithCredentialsEntity,
-                 config: AgentEntity,
-                 queue_manager: AppQueueManager,
-                 message: Message,
-                 user_id: str,
-                 memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
-                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
-                 db_variables: Optional[ToolConversationVariables] = None,
-                 model_instance: ModelInstance = None
-                 ) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        conversation: Conversation,
+        app_config: AgentChatAppConfig,
+        model_config: ModelConfigWithCredentialsEntity,
+        config: AgentEntity,
+        queue_manager: AppQueueManager,
+        message: Message,
+        user_id: str,
+        memory: Optional[TokenBufferMemory] = None,
+        prompt_messages: Optional[list[PromptMessage]] = None,
+        variables_pool: Optional[ToolRuntimeVariablePool] = None,
+        db_variables: Optional[ToolConversationVariables] = None,
+        model_instance: ModelInstance = None,
+    ) -> None:
         """
         Agent runner
         :param tenant_id: tenant id
@@ -88,9 +92,7 @@
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -111,12 +113,16 @@
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
             return_resource=app_config.additional_features.show_retrieve_source,
             invoke_from=application_generate_entity.invoke_from,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
         )
         # get how many agent thoughts have been created
-        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.message_id == self.message.id,
-        ).count()
+        self.agent_thought_count = (
+            db.session.query(MessageAgentThought)
+            .filter(
+                MessageAgentThought.message_id == self.message.id,
+            )
+            .count()
+        )
         db.session.close()
 
         # check if model supports stream tool call
@@ -135,25 +141,26 @@
         self.query = None
         self._current_thoughts: list[PromptMessage] = []
 
-    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-            -> AgentChatAppGenerateEntity:
+    def _repack_app_generate_entity(
+        self, app_generate_entity: AgentChatAppGenerateEntity
+    ) -> AgentChatAppGenerateEntity:
         """
         Repack app generate entity
         """
         if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
-            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
+            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
 
         return app_generate_entity
 
     def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
-        convert tool to prompt message tool
+        convert tool to prompt message tool
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
             app_id=self.app_config.app_id,
             agent_tool=tool,
-            invoke_from=self.application_generate_entity.invoke_from
+            invoke_from=self.application_generate_entity.invoke_from,
         )
         tool_entity.load_variables(self.variables_pool)
 
@@ -164,7 +171,7 @@
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         parameters = tool_entity.get_all_runtime_parameters()
@@ -177,19 +184,19 @@
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
 
-            message_tool.parameters['properties'][parameter.name] = {
+            message_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                message_tool.parameters['properties'][parameter.name]['enum'] = enum
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                message_tool.parameters['required'].append(parameter.name)
+                message_tool.parameters["required"].append(parameter.name)
 
         return message_tool, tool_entity
 
     def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
         """
         convert dataset retriever tool to prompt message tool
@@ -201,24 +208,24 @@
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         for parameter in tool.get_runtime_parameters():
-            parameter_type = 'string'
-
-            prompt_tool.parameters['properties'][parameter.name] = {
+            parameter_type = "string"
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
 
-    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
         """
         Init tools
         """
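Note: the `_init_prompt_tools` hunk above is one of the few semantic type-hint changes rather than pure formatting: the return annotation moves from concrete `dict`/`list` to abstract `Mapping`/`Sequence`, telling callers to treat the result as read-only. A self-contained sketch of the idea (toy types, not from this commit):

    from collections.abc import Mapping, Sequence


    def load_tools() -> tuple[Mapping[str, object], Sequence[str]]:
        # A plain dict and list still satisfy Mapping and Sequence,
        # so the implementation does not change, only the contract.
        instances: dict[str, object] = {"search": object()}
        prompts: list[str] = ["search"]
        return instances, prompts


    tools, prompts = load_tools()
    print(list(tools), list(prompts))
    # tools["web"] = object()  # a type checker rejects this: Mapping has no __setitem__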
@@ -261,51 +268,51 @@
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
 
-            prompt_tool.parameters['properties'][parameter.name] = {
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
+                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
 
-    def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: list[str]
-                             ) -> MessageAgentThought:
+    def create_agent_thought(
+        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
+    ) -> MessageAgentThought:
         """
         Create agent thought
         """
         thought = MessageAgentThought(
             message_id=message_id,
             message_chain_id=None,
-            thought='',
+            thought="",
             tool=tool_name,
-            tool_labels_str='{}',
-            tool_meta_str='{}',
+            tool_labels_str="{}",
+            tool_meta_str="{}",
             tool_input=tool_input,
             message=message,
             message_token=0,
             message_unit_price=0,
             message_price_unit=0,
-            message_files=json.dumps(messages_ids) if messages_ids else '',
-            answer='',
-            observation='',
+            message_files=json.dumps(messages_ids) if messages_ids else "",
+            answer="",
+            observation="",
             answer_token=0,
             answer_unit_price=0,
             answer_price_unit=0,
             tokens=0,
             total_price=0,
             position=self.agent_thought_count + 1,
-            currency='USD',
+            currency="USD",
             latency=0,
-            created_by_role='account',
+            created_by_role="account",
             created_by=self.user_id,
         )
@@ -318,22 +325,22 @@
 
         return thought
 
-    def save_agent_thought(self,
-                           agent_thought: MessageAgentThought,
-                           tool_name: str,
-                           tool_input: Union[str, dict],
-                           thought: str,
-                           observation: Union[str, dict],
-                           tool_invoke_meta: Union[str, dict],
-                           answer: str,
-                           messages_ids: list[str],
-                           llm_usage: LLMUsage = None) -> MessageAgentThought:
+    def save_agent_thought(
+        self,
+        agent_thought: MessageAgentThought,
+        tool_name: str,
+        tool_input: Union[str, dict],
+        thought: str,
+        observation: Union[str, dict],
+        tool_invoke_meta: Union[str, dict],
+        answer: str,
+        messages_ids: list[str],
+        llm_usage: LLMUsage = None,
+    ) -> MessageAgentThought:
         """
         Save agent thought
         """
-        agent_thought = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.id == agent_thought.id
-        ).first()
+        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
 
         if thought is not None:
             agent_thought.thought = thought
@@ -356,7 +363,7 @@
                     observation = json.dumps(observation, ensure_ascii=False)
                 except Exception as e:
                     observation = json.dumps(observation)
-
+
             agent_thought.observation = observation
 
         if answer is not None:
@@ -364,7 +371,7 @@
 
         if messages_ids is not None and len(messages_ids) > 0:
             agent_thought.message_files = json.dumps(messages_ids)
-
+
         if llm_usage:
             agent_thought.message_token = llm_usage.prompt_tokens
             agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -377,7 +384,7 @@
 
             # check if tool labels is not empty
             labels = agent_thought.tool_labels or {}
-            tools = agent_thought.tool.split(';') if agent_thought.tool else []
+            tools = agent_thought.tool.split(";") if agent_thought.tool else []
             for tool in tools:
                 if not tool:
                     continue
@@ -386,7 +393,7 @@
                 if tool_label:
                     labels[tool] = tool_label.to_dict()
                 else:
-                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+                    labels[tool] = {"en_US": tool, "zh_Hans": tool}
 
             agent_thought.tool_labels_str = json.dumps(labels)
 
@@ -401,14 +408,18 @@
 
         db.session.commit()
         db.session.close()
 
     def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
         """
         convert tool variables to db variables
         """
-        db_variables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == self.message.conversation_id,
-        ).first()
+        db_variables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == self.message.conversation_id,
+            )
+            .first()
+        )
 
         db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
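Note: the query rewrites above and below swap trailing-paren chains like `query(...).filter(...).first()` for the parenthesized fluent layout, one call per line. Behavior is identical; the wrapping parentheses only let the chain break across lines without backslashes. A runnable sketch of the same layout against an in-memory SQLite database (hypothetical model, not dify's schema):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class ConversationVariable(Base):
        __tablename__ = "conversation_variables"
        id = Column(Integer, primary_key=True)
        conversation_id = Column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(ConversationVariable(conversation_id="c1"))
        session.commit()
        # The same fluent, parenthesized layout used in the hunks above.
        row = (
            session.query(ConversationVariable)
            .filter(ConversationVariable.conversation_id == "c1")
            .first()
        )
        print(row.id)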
@@ -425,9 +436,14 @@
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)
 
-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
+        messages: list[Message] = (
+            db.session.query(Message)
+            .filter(
+                Message.conversation_id == self.message.conversation_id,
+            )
+            .order_by(Message.created_at.asc())
+            .all()
+        )
 
         for message in messages:
             if message.id == self.message.id:
@@ -439,13 +455,13 @@
                 for agent_thought in agent_thoughts:
                     tools = agent_thought.tool
                     if tools:
-                        tools = tools.split(';')
+                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
-                            tool_inputs = { tool: {} for tool in tools }
+                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception as e:
@@ -454,27 +470,33 @@
                         for tool in tools:
                             # generate a uuid for tool call
                             tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                    name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
-                                )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
+                            tool_calls.append(
+                                AssistantPromptMessage.ToolCall(
+                                    id=tool_call_id,
+                                    type="function",
+                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                        name=tool,
+                                        arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    ),
+                                )
+                            )
+                            tool_call_response.append(
+                                ToolPromptMessage(
+                                    content=tool_responses.get(tool, agent_thought.observation),
+                                    name=tool,
+                                    tool_call_id=tool_call_id,
+                                )
+                            )
 
-                        result.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
+                        result.extend(
+                            [
+                                AssistantPromptMessage(
+                                    content=agent_thought.thought,
+                                    tool_calls=tool_calls,
+                                ),
+                                *tool_call_response,
+                            ]
+                        )
                     if not tools:
                         result.append(AssistantPromptMessage(content=agent_thought.thought))
                 else:
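Note: for context on the block above, `organize_agent_history` replays persisted agent thoughts as prompt messages, minting a fresh uuid per tool call so each assistant `tool_call` can be paired with its `ToolPromptMessage` response. A simplified, runnable sketch with plain dicts standing in for dify's prompt-message classes (hypothetical data shapes):

    import json
    import uuid

    # A persisted agent thought as it might come back from the DB (hypothetical shape).
    thought = {"tool": "search;calculator", "tool_input": '{"search": {"q": "dify"}}', "observation": "{}"}

    tools = thought["tool"].split(";")
    try:
        tool_inputs = json.loads(thought["tool_input"])
    except Exception:
        tool_inputs = {tool: {} for tool in tools}

    tool_calls, tool_responses = [], []
    for tool in tools:
        # Original call ids are not persisted, so a fresh uuid pairs each
        # assistant tool_call with its tool response on replay.
        tool_call_id = str(uuid.uuid4())
        tool_calls.append({
            "id": tool_call_id,
            "type": "function",
            "function": {"name": tool, "arguments": json.dumps(tool_inputs.get(tool, {}))},
        })
        tool_responses.append({"role": "tool", "name": tool, "tool_call_id": tool_call_id, "content": ""})

    history = [{"role": "assistant", "tool_calls": tool_calls}, *tool_responses]
    print(json.dumps(history, indent=2))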
@@ -496,10 +518,7 @@
             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
 
             if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    file_extra_config
-                )
+                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
             else:
                 file_objs = []
 
api/core/model_manager.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Sequence
 from typing import IO, Optional, Union, cast
 
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -41,7 +41,7 @@ class ModelInstance:
             configuration=provider_model_bundle.configuration,
             model_type=provider_model_bundle.model_type_instance.model_type,
             model=model,
-            credentials=self.credentials
+            credentials=self.credentials,
         )
 
     @staticmethod
@@ -54,10 +54,7 @@
         """
         configuration = provider_model_bundle.configuration
         model_type = provider_model_bundle.model_type_instance.model_type
-        credentials = configuration.get_current_credentials(
-            model_type=model_type,
-            model=model
-        )
+        credentials = configuration.get_current_credentials(model_type=model_type, model=model)
 
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
@@ -65,10 +62,9 @@
         return credentials
 
     @staticmethod
-    def _get_load_balancing_manager(configuration: ProviderConfiguration,
-                                    model_type: ModelType,
-                                    model: str,
-                                    credentials: dict) -> Optional["LBModelManager"]:
+    def _get_load_balancing_manager(
+        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
+    ) -> Optional["LBModelManager"]:
         """
         Get load balancing model credentials
         :param configuration: provider configuration
@@ -81,8 +77,7 @@
             current_model_setting = None
             # check if model is disabled by admin
             for model_setting in configuration.model_settings:
-                if (model_setting.model_type == model_type
-                        and model_setting.model == model):
+                if model_setting.model_type == model_type and model_setting.model == model:
                     current_model_setting = model_setting
                     break
 
@@ -95,17 +90,23 @@
                 model_type=model_type,
                 model=model,
                 load_balancing_configs=current_model_setting.load_balancing_configs,
-                managed_credentials=credentials if configuration.custom_configuration.provider else None
+                managed_credentials=credentials if configuration.custom_configuration.provider else None,
             )
 
             return lb_model_manager
 
         return None
 
-    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+    def invoke_llm(
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: Optional[dict] = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: Optional[list[Callback]] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
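Note: one detail in the new `invoke_llm` signature above is that `tools` switches to PEP 604 syntax (`Sequence[PromptMessageTool] | None`) while its neighbors keep `Optional[...]`. The two spellings are interchangeable on Python 3.10+; a tiny demonstration:

    from collections.abc import Sequence
    from typing import Optional


    def f(tools: Sequence[str] | None = None) -> int:
        return len(tools or ())


    def g(tools: Optional[Sequence[str]] = None) -> int:
        return len(tools or ())


    assert f(["a"]) == g(["a"]) == 1
    assert f() == g() == 0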
@@ -132,11 +133,12 @@
             stop=stop,
             stream=stream,
             user=user,
-            callbacks=callbacks
+            callbacks=callbacks,
         )
 
-    def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
-                           tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_llm_num_tokens(
+        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """
         Get number of tokens for llm
 
@@ -153,11 +155,10 @@
             model=self.model,
             credentials=self.credentials,
             prompt_messages=prompt_messages,
-            tools=tools
+            tools=tools,
         )
 
-    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-            -> TextEmbeddingResult:
+    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
         """
         Invoke large language model
 
@@ -174,7 +175,7 @@
             model=self.model,
             credentials=self.credentials,
             texts=texts,
-            user=user
+            user=user,
         )
 
     def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@@ -192,13 +193,17 @@
             function=self.model_type_instance.get_num_tokens,
             model=self.model,
             credentials=self.credentials,
-            texts=texts
+            texts=texts,
         )
 
-    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
-                      top_n: Optional[int] = None,
-                      user: Optional[str] = None) \
-            -> RerankResult:
+    def invoke_rerank(
+        self,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
         """
         Invoke rerank model
 
@@ -221,11 +226,10 @@
             docs=docs,
             score_threshold=score_threshold,
             top_n=top_n,
-            user=user
+            user=user,
         )
 
-    def invoke_moderation(self, text: str, user: Optional[str] = None) \
-            -> bool:
+    def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
         """
         Invoke moderation model
 
@@ -242,11 +246,10 @@
             model=self.model,
             credentials=self.credentials,
             text=text,
-            user=user
+            user=user,
         )
 
-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-            -> str:
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
         """
         Invoke large language model
 
@@ -263,11 +266,10 @@
             model=self.model,
             credentials=self.credentials,
             file=file,
-            user=user
+            user=user,
         )
 
-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-            -> str:
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
         """
         Invoke large language tts model
 
@@ -288,7 +290,7 @@
             content_text=content_text,
             user=user,
             tenant_id=tenant_id,
-            voice=voice
+            voice=voice,
         )
 
     def _round_robin_invoke(self, function: Callable, *args, **kwargs):
@@ -312,8 +314,8 @@
                     raise last_exception
 
             try:
-                if 'credentials' in kwargs:
-                    del kwargs['credentials']
+                if "credentials" in kwargs:
+                    del kwargs["credentials"]
                 return function(*args, **kwargs, credentials=lb_config.credentials)
             except InvokeRateLimitError as e:
                 # expire in 60 seconds
@@ -340,9 +342,7 @@
 
         self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return self.model_type_instance.get_tts_model_voices(
-            model=self.model,
-            credentials=self.credentials,
-            language=language
+            model=self.model, credentials=self.credentials, language=language
         )
 
 
@@ -363,9 +363,7 @@ class ModelManager:
             return self.get_default_model_instance(tenant_id, model_type)
 
         provider_model_bundle = self._provider_manager.get_provider_model_bundle(
-            tenant_id=tenant_id,
-            provider=provider,
-            model_type=model_type
+            tenant_id=tenant_id, provider=provider, model_type=model_type
         )
 
         return ModelInstance(provider_model_bundle, model)
@@ -386,10 +384,7 @@
         :param model_type: model type
         :return:
         """
-        default_model_entity = self._provider_manager.get_default_model(
-            tenant_id=tenant_id,
-            model_type=model_type
-        )
+        default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)
 
         if not default_model_entity:
             raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
@@ -398,17 +393,20 @@
             tenant_id=tenant_id,
             provider=default_model_entity.provider.provider,
             model_type=model_type,
-            model=default_model_entity.model
+            model=default_model_entity.model,
         )
 
 
 class LBModelManager:
-    def __init__(self, tenant_id: str,
-                 provider: str,
-                 model_type: ModelType,
-                 model: str,
-                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
-                 managed_credentials: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        provider: str,
+        model_type: ModelType,
+        model: str,
+        load_balancing_configs: list[ModelLoadBalancingConfiguration],
+        managed_credentials: Optional[dict] = None,
+    ) -> None:
         """
         Load balancing model manager
         :param tenant_id: tenant_id
@@ -439,10 +437,7 @@
         :return:
         """
         cache_key = "model_lb_index:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model
+            self._tenant_id, self._provider, self._model_type.value, self._model
         )
 
         cooldown_load_balancing_configs = []
@@ -473,10 +468,12 @@
 
                 continue
 
-            if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
-                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
-                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
-                            f"model_type: {self._model_type.value}\nmodel: {self._model}")
+            if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+                logger.info(
+                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
+                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
+                )
 
             return config
 
@@ -490,14 +487,10 @@
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )
 
-        redis_client.setex(cooldown_cache_key, expire, 'true')
+        redis_client.setex(cooldown_cache_key, expire, "true")
 
     def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
         """
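Note: the cooldown above is just a Redis key with a TTL: `setex` marks a config as cooling down, `exists` answers `in_cooldown`, and `ttl` backs the static helper below. A minimal sketch with redis-py (assumes a local Redis; key fields are placeholders):

    import redis

    r = redis.Redis()
    cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
        "tenant-1", "openai", "llm", "model-x", "config-1"  # placeholder values
    )

    r.setex(cooldown_cache_key, 60, "true")   # cool the config down for 60 seconds
    in_cooldown = bool(r.exists(cooldown_cache_key))
    ttl = r.ttl(cooldown_cache_key)           # remaining cooldown in seconds
    print(in_cooldown, ttl)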
@@ -506,11 +499,7 @@
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            self._tenant_id,
-            self._provider,
-            self._model_type.value,
-            self._model,
-            config.id
+            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )
 
         res = redis_client.exists(cooldown_cache_key)
@@ -518,11 +507,9 @@
         return res
 
     @staticmethod
-    def get_config_in_cooldown_and_ttl(tenant_id: str,
-                                       provider: str,
-                                       model_type: ModelType,
-                                       model: str,
-                                       config_id: str) -> tuple[bool, int]:
+    def get_config_in_cooldown_and_ttl(
+        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
+    ) -> tuple[bool, int]:
         """
         Get model load balancing config is in cooldown and ttl
         :param tenant_id: workspace id
@@ -533,11 +520,7 @@
         :return:
         """
         cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
-            tenant_id,
-            provider,
-            model_type.value,
-            model,
-            config_id
+            tenant_id, provider, model_type.value, model, config_id
        )
 
        ttl = redis_client.ttl(cooldown_cache_key)