Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)

chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
parent 178730266d
commit 2cf1187b32
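For context, a change like this is typically produced by running Ruff's formatter over the package and then verifying the result. The commands below are a minimal sketch, not dify's actual dev tooling; the exact script, paths, and pyproject configuration are assumptions here, but the churn visible in the diff (single quotes becoming double quotes, hanging-indent call sites collapsed or re-wrapped, trailing commas added) matches Ruff's formatter defaults.

    # A minimal sketch, assuming a standard Ruff setup; dify's own
    # reformat script or CI target may differ.
    pip install ruff

    # Rewrite files in place. Ruff's formatter defaults to double quotes
    # (quote-style = "double") and respects magic trailing commas.
    ruff format api/core

    # Verify nothing is left to reformat (useful as a CI gate).
    ruff format --check api/core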
@@ -1 +1 @@
-import core.moderation.base
+import core.moderation.base

@@ -25,17 +25,19 @@ from models.model import Message

 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None

-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """

@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager

         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")

         app_config = self.app_config

         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price

@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []

             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()

@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")

             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )

             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):

@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )

-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)

             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )

             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool

@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )

-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )

             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:

@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             model=model_instance.model,
             prompt_messages=prompt_messages,
             delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(
-                    content=final_answer
-                ),
-                usage=llm_usage['usage']
+                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
             ),
-            system_fingerprint=''
+            system_fingerprint="",
         )

         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action

@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)

@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """

@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue

@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        organize prompt messages
+        organize prompt messages
         """

     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
-        format assistant message
+        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"

@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         return message

-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
-        organize historic prompt messages
+        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []

@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if not current_scratchpad:
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
-                        thought=message.content or 'I am thinking about how to help you',
-                        action_str='',
+                        thought=message.content or "I am thinking about how to help you",
+                        action_str="",
                         action=None,
                         observation=None,
                     )

@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 try:
                     current_scratchpad.action = AgentScratchpadUnit.Action(
                         action_name=message.tool_calls[0].function.name,
-                        action_input=json.loads(
-                            message.tool_calls[0].function.arguments)
+                        action_input=json.loads(message.tool_calls[0].function.arguments),
                     )
-                    current_scratchpad.action_str = json.dumps(
-                        current_scratchpad.action.to_dict()
-                    )
+                    current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                 except:
                     pass
             elif isinstance(message, ToolPromptMessage):

@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None

                 result.append(message)

         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
             history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts

@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return SystemPromptMessage(content=system_prompt)

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """

@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        Organize
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()

@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"

@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages

@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return system_prompt

     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:

@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"

@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"

         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
-            .replace("{{query}}", query_prompt)
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
+            .replace("{{query}}", query_prompt)
+        )

-        return [UserPromptMessage(content=prompt)]
+        return [UserPromptMessage(content=prompt)]

@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
+
     provider_type: Literal["builtin", "api", "workflow"]
     provider_id: str
     tool_name: str

@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """
+
     first_prompt: str
     next_iteration: str

@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """
+
         action_name: str
         action_input: Union[dict, str]

@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
            Convert to dictionary.
            """
            return {
-               'action': self.action_name,
-               'action_input': self.action_input,
+               "action": self.action_name,
+               "action_input": self.action_input,
            }

    agent_response: Optional[str] = None

@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
         Check if the scratchpad unit is final.
         """
         return self.action is None or (
-            'final' in self.action.action_name.lower() and
-            'answer' in self.action.action_name.lower()
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
         )


 class AgentEntity(BaseModel):
     """
     Agent Entity.

@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"

     provider: str
     model: str

@@ -24,11 +24,9 @@ from models.model import Message
 logger = logging.getLogger(__name__)


-class FunctionCallAgentRunner(BaseAgentRunner):
-
-    def run(self,
-            message: Message, query: str, **kwargs: Any
-            ) -> Generator[LLMResultChunk, None, None]:
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """

@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         # get tracing instance
         trace_manager = app_generate_entity.trace_manager

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price

@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             # recalc llm max tokens

@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []

             # save full response
-            response = ''
+            response = ""

             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""

             current_llm_usage = None

@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):

@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
                     tool_calls.extend(self.extract_blocking_tool_calls(result))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        }, ensure_ascii=False)
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
                     except json.JSONDecodeError as e:
                         # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        })
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                 if result.usage:
                     increase_usage(llm_usage, result.usage)

@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     response += result.message.content

                 if not result.message.content:
-                    result.message.content = ''
+                    result.message.content = ""

-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

                 yield LLMResultChunk(
                     model=model_instance.model,
                     prompt_messages=result.prompt_messages,

@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         index=0,
                         message=result.message,
                         usage=result.usage,
-                    )
+                    ),
                 )

-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response

             self._current_thoughts.append(assistant_message)

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,

@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )

-            final_answer += response + '\n'
+            final_answer += response + "\n"

             # call tools
             tool_responses = []

@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool

@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                        self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
-                    thought=None,
+                    thought=None,
                    tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta']
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
-                        tool_response['tool_call_name']: tool_response['tool_response']
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
                )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:

@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """

@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
-
+
     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result

@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False

-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk

@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result

@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """

@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]
-
+
         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

         return prompt_messages

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """

@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_messages.append(UserPromptMessage(content=query))

         return prompt_messages
-
+
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.

@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]'
-                        for content in prompt_message.content
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )

         return prompt_messages

     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])

@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()

-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk

 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-            Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)

@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]

                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value

@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
+                return json_str or ""

         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)

-        code_block_cache = ''
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False

-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0

-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0

         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue

@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""

-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                            code_block_cache = ''
+                            code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0

                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue

@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue

@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0

                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue

@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue

@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0

                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
+                        code_block_cache = ""

                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0

                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1

@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                 if got_json:
                     got_json = False
                     yield parse_action(json_cache)
-                    json_cache = ''
+                    json_cache = ""
                     json_quote_count = 0
                     in_json = False

                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")

                 index += steps

@@ -186,4 +187,3 @@ class CotAgentOutputParser:
         if json_cache:
             yield parse_action(json_cache)
-

@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
-    }
+    }
 }

@@ -26,34 +26,24 @@ class BaseAppConfigManager:
         config_dict = dict(config_dict.items())

         additional_features = AppAdditionalFeatures()
-        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)

         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict,
-            is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
         )

-        additional_features.opening_statement, additional_features.suggested_questions = \
-            OpeningStatementConfigManager.convert(
-                config=config_dict
-            )
+        additional_features.opening_statement, additional_features.suggested_questions = (
+            OpeningStatementConfigManager.convert(config=config_dict)
+        )

         additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
             config=config_dict
         )

-        additional_features.more_like_this = MoreLikeThisConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)

-        additional_features.speech_to_text = SpeechToTextConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)

-        additional_features.text_to_speech = TextToSpeechConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)

         return additional_features

@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory

 class SensitiveWordAvoidanceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
-        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
+        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None

-        if sensitive_word_avoidance_dict.get('enabled'):
+        if sensitive_word_avoidance_dict.get("enabled"):
             return SensitiveWordAvoidanceEntity(
-                type=sensitive_word_avoidance_dict.get('type'),
-                config=sensitive_word_avoidance_dict.get('config'),
+                type=sensitive_word_avoidance_dict.get("type"),
+                config=sensitive_word_avoidance_dict.get("config"),
             )
         else:
             return None

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-            -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(
+        cls, tenant_id, config: dict, only_structure_validate: bool = False
+    ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
-            config["sensitive_word_avoidance"] = {
-                "enabled": False
-            }
+            config["sensitive_word_avoidance"] = {"enabled": False}

         if not isinstance(config["sensitive_word_avoidance"], dict):
             raise ValueError("sensitive_word_avoidance must be of dict type")

@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
             typ = config["sensitive_word_avoidance"]["type"]
             sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]

-            ModerationFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=sensitive_word_avoidance_config
-            )
+            ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

         return config, ["sensitive_word_avoidance"]

@@ -12,67 +12,70 @@ class AgentConfigManager:
         :param config: model config args
         """
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode']:
+        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
+            agent_dict = config.get("agent_mode", {})
+            agent_strategy = agent_dict.get("strategy", "cot")

-            agent_dict = config.get('agent_mode', {})
-            agent_strategy = agent_dict.get('strategy', 'cot')
-
-            if agent_strategy == 'function_call':
+            if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == 'cot' or agent_strategy == 'react':
+            elif agent_strategy == "cot" or agent_strategy == "react":
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
                 # old configs, try to detect default strategy
-                if config['model']['provider'] == 'openai':
+                if config["model"]["provider"] == "openai":
                     strategy = AgentEntity.Strategy.FUNCTION_CALLING
                 else:
                     strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

             agent_tools = []
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) >= 4:
                     if "enabled" not in tool or not tool["enabled"]:
                         continue

                     agent_tool_properties = {
-                        'provider_type': tool['provider_type'],
-                        'provider_id': tool['provider_id'],
-                        'tool_name': tool['tool_name'],
-                        'tool_parameters': tool.get('tool_parameters', {})
+                        "provider_type": tool["provider_type"],
+                        "provider_id": tool["provider_id"],
+                        "tool_name": tool["tool_name"],
+                        "tool_parameters": tool.get("tool_parameters", {}),
                     }

                     agent_tools.append(AgentToolEntity(**agent_tool_properties))

-            if 'strategy' in config['agent_mode'] and \
-                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
-                agent_prompt = agent_dict.get('prompt', None) or {}
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+                "react_router",
+                "router",
+            ]:
+                agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
-                model_mode = config.get('model', {}).get('mode', 'completion')
-                if model_mode == 'completion':
+                model_mode = config.get("model", {}).get("mode", "completion")
+                if model_mode == "completion":
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['completion'][
-                                                            'agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
+                        ),
                     )
                 else:
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
+                        ),
                     )

             return AgentEntity(
-                provider=config['model']['provider'],
-                model=config['model']['name'],
+                provider=config["model"]["provider"],
+                model=config["model"]["name"],
                 strategy=strategy,
                 prompt=agent_prompt_entity,
                 tools=agent_tools,
-                max_iteration=agent_dict.get('max_iteration', 5)
+                max_iteration=agent_dict.get("max_iteration", 5),
             )

         return None

@@ -15,39 +15,38 @@ class DatasetConfigManager:
         :param config: model config args
         """
         dataset_ids = []
-        if 'datasets' in config.get('dataset_configs', {}):
-            datasets = config.get('dataset_configs', {}).get('datasets', {
-                'strategy': 'router',
-                'datasets': []
-            })
+        if "datasets" in config.get("dataset_configs", {}):
+            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})

-            for dataset in datasets.get('datasets', []):
+            for dataset in datasets.get("datasets", []):
                 keys = list(dataset.keys())
-                if len(keys) == 0 or keys[0] != 'dataset':
+                if len(keys) == 0 or keys[0] != "dataset":
                     continue

-                dataset = dataset['dataset']
+                dataset = dataset["dataset"]

-                if 'enabled' not in dataset or not dataset['enabled']:
+                if "enabled" not in dataset or not dataset["enabled"]:
                     continue

-                dataset_id = dataset.get('id', None)
+                dataset_id = dataset.get("id", None)
                 if dataset_id:
                     dataset_ids.append(dataset_id)

-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode'] \
-                and config['agent_mode']['enabled']:
+        if (
+            "agent_mode" in config
+            and config["agent_mode"]
+            and "enabled" in config["agent_mode"]
+            and config["agent_mode"]["enabled"]
+        ):
+            agent_dict = config.get("agent_mode", {})

-            agent_dict = config.get('agent_mode', {})
-
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) == 1:
                     # old standard
                     key = list(tool.keys())[0]

-                    if key != 'dataset':
+                    if key != "dataset":
                         continue

                     tool_item = tool[key]

@@ -55,30 +54,28 @@ class DatasetConfigManager:
                     if "enabled" not in tool_item or not tool_item["enabled"]:
                         continue

-                    dataset_id = tool_item['id']
+                    dataset_id = tool_item["id"]
                     dataset_ids.append(dataset_id)

         if len(dataset_ids) == 0:
             return None

         # dataset configs
-        if 'dataset_configs' in config and config.get('dataset_configs'):
-            dataset_configs = config.get('dataset_configs')
+        if "dataset_configs" in config and config.get("dataset_configs"):
+            dataset_configs = config.get("dataset_configs")
         else:
-            dataset_configs = {
-                'retrieval_model': 'multiple'
-            }
-        query_variable = config.get('dataset_query_variable')
+            dataset_configs = {"retrieval_model": "multiple"}
+        query_variable = config.get("dataset_query_variable")

-        if dataset_configs['retrieval_model'] == 'single':
+        if dataset_configs["retrieval_model"] == "single":
             return DatasetEntity(
                 dataset_ids=dataset_ids,
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
-                    )
-                )
+                        dataset_configs["retrieval_model"]
+                    ),
+                ),
             )
         else:
             return DatasetEntity(

@@ -86,15 +83,15 @@ class DatasetConfigManager:
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
+                        dataset_configs["retrieval_model"]
                     ),
-                    top_k=dataset_configs.get('top_k', 4),
-                    score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model'),
-                    weights=dataset_configs.get('weights'),
-                    reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                    rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
-                )
+                    top_k=dataset_configs.get("top_k", 4),
+                    score_threshold=dataset_configs.get("score_threshold"),
+                    reranking_model=dataset_configs.get("reranking_model"),
+                    weights=dataset_configs.get("weights"),
+                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
+                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                ),
             )

     @classmethod

@@ -111,13 +108,10 @@ class DatasetConfigManager:
         # dataset_configs
         if not config.get("dataset_configs"):
-            config["dataset_configs"] = {'retrieval_model': 'single'}
+            config["dataset_configs"] = {"retrieval_model": "single"}

         if not config["dataset_configs"].get("datasets"):
-            config["dataset_configs"]["datasets"] = {
-                "strategy": "router",
-                "datasets": []
-            }
+            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")

@@ -125,8 +119,9 @@ class DatasetConfigManager:
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")

-        need_manual_query_datasets = (config.get("dataset_configs")
-                                      and config["dataset_configs"].get("datasets", {}).get("datasets"))
+        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
+            "datasets", {}
||||
).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
@ -148,10 +143,7 @@ class DatasetConfigManager:
|
|||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {
|
||||
"enabled": False,
|
||||
"tools": []
|
||||
}
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
@ -188,7 +180,7 @@ class DatasetConfigManager:
|
|||
if not isinstance(tool_item["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if 'id' not in tool_item:
|
||||
if "id" not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
|
|
|
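One recurring rewrite in the DatasetConfigManager hunks is worth calling out: backslash line continuations in long boolean conditions are replaced by a parenthesized condition with one clause per line, the style Black and ruff emit. A runnable sketch of the equivalence (the `config` value is illustrative):

config = {"agent_mode": {"enabled": True}}

# Old style, with backslash continuations:
# if 'agent_mode' in config and config['agent_mode'] \
#         and 'enabled' in config['agent_mode'] \
#         and config['agent_mode']['enabled']:

# Formatted style: explicit parentheses, one clause per line.
if (
    "agent_mode" in config
    and config["agent_mode"]
    and "enabled" in config["agent_mode"]
    and config["agent_mode"]["enabled"]
):
    print("agent mode enabled")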
@@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager

class ModelConfigConverter:
    @classmethod
    def convert(cls, app_config: EasyUIBasedAppConfig,
                skip_check: bool = False) \
            -> ModelConfigWithCredentialsEntity:
    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
        """
        Convert app model config dict to entity.
        :param app_config: app config

@@ -25,9 +23,7 @@ class ModelConfigConverter:

        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=app_config.tenant_id,
            provider=model_config.provider,
            model_type=ModelType.LLM
            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
        )

        provider_name = provider_model_bundle.configuration.provider.provider

@@ -38,8 +34,7 @@ class ModelConfigConverter:

        # check model credentials
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model_config.model
            model_type=ModelType.LLM, model=model_config.model
        )

        if model_credentials is None:

@@ -51,8 +46,7 @@ class ModelConfigConverter:
        if not skip_check:
            # check model
            provider_model = provider_model_bundle.configuration.get_provider_model(
                model=model_config.model,
                model_type=ModelType.LLM
                model=model_config.model, model_type=ModelType.LLM
            )

            if provider_model is None:

@@ -69,24 +63,18 @@ class ModelConfigConverter:
        # model config
        completion_params = model_config.parameters
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model_config.mode
        if not model_mode:
            mode_enum = model_type_instance.get_model_mode(
                model=model_config.model,
                credentials=model_credentials
            )
            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)

            model_mode = mode_enum.value

        model_schema = model_type_instance.get_model_schema(
            model_config.model,
            model_credentials
        )
        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

        if not skip_check and not model_schema:
            raise ValueError(f"Model {model_name} not exist.")
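The converter hunks show the opposite direction: keyword arguments that were split across lines are joined whenever the whole call fits within the configured line length (apparently 120 characters here, judging by the width of the joined lines). Whether a call collapses is controlled by the magic trailing comma: with no comma after the last argument the formatter may join the call, while a trailing comma pins it in exploded form. A small illustration with a hypothetical helper:

def get_provider_model_bundle(*, tenant_id: str, provider: str, model_type: str) -> tuple:
    # Hypothetical stand-in for the real ProviderManager method.
    return (tenant_id, provider, model_type)

# No magic trailing comma inside the call, so it may be joined onto one line:
bundle = get_provider_model_bundle(tenant_id="t1", provider="openai", model_type="llm")
assert bundle == ("t1", "openai", "llm")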
@@ -13,23 +13,23 @@ class ModelConfigManager:
        :param config: model config args
        """
        # model config
        model_config = config.get('model')
        model_config = config.get("model")

        if not model_config:
            raise ValueError("model is required")

        completion_params = model_config.get('completion_params')
        completion_params = model_config.get("completion_params")
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model_config.get('mode')
        model_mode = model_config.get("mode")

        return ModelConfigEntity(
            provider=config['model']['provider'],
            model=config['model']['name'],
            provider=config["model"]["provider"],
            model=config["model"]["name"],
            mode=model_mode,
            parameters=completion_params,
            stop=stop,

@@ -43,7 +43,7 @@ class ModelConfigManager:
        :param tenant_id: tenant id
        :param config: app model config args
        """
        if 'model' not in config:
        if "model" not in config:
            raise ValueError("model is required")

        if not isinstance(config["model"], dict):

@@ -52,17 +52,16 @@ class ModelConfigManager:
        # model.provider
        provider_entities = model_provider_factory.get_providers()
        model_provider_names = [provider.provider for provider in provider_entities]
        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

        # model.name
        if 'name' not in config["model"]:
        if "name" not in config["model"]:
            raise ValueError("model.name is required")

        provider_manager = ProviderManager()
        models = provider_manager.get_configurations(tenant_id).get_models(
            provider=config["model"]["provider"],
            model_type=ModelType.LLM
            provider=config["model"]["provider"], model_type=ModelType.LLM
        )

        if not models:

@@ -80,12 +79,12 @@ class ModelConfigManager:

        # model.mode
        if model_mode:
            config['model']["mode"] = model_mode
            config["model"]["mode"] = model_mode
        else:
            config['model']["mode"] = "completion"
            config["model"]["mode"] = "completion"

        # model.completion_params
        if 'completion_params' not in config["model"]:
        if "completion_params" not in config["model"]:
            raise ValueError("model.completion_params is required")

        config["model"]["completion_params"] = cls.validate_model_completion_params(

@@ -101,7 +100,7 @@ class ModelConfigManager:
            raise ValueError("model.completion_params must be of object type")

        # stop
        if 'stop' not in cp:
        if "stop" not in cp:
            cp["stop"] = []
        elif not isinstance(cp["stop"], list):
            raise ValueError("stop in model.completion_params must be of list type")
@@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
        if not config.get("prompt_type"):
            raise ValueError("prompt_type is required")

        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            simple_prompt_template = config.get("pre_prompt", "")
            return PromptTemplateEntity(
                prompt_type=prompt_type,
                simple_prompt_template=simple_prompt_template
            )
            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
        else:
            advanced_chat_prompt_template = None
            chat_prompt_config = config.get("chat_prompt_config", {})
            if chat_prompt_config:
                chat_prompt_messages = []
                for message in chat_prompt_config.get("prompt", []):
                    chat_prompt_messages.append({
                        "text": message["text"],
                        "role": PromptMessageRole.value_of(message["role"])
                    })
                    chat_prompt_messages.append(
                        {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
                    )

                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                    messages=chat_prompt_messages
                )
                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

            advanced_completion_prompt_template = None
            completion_prompt_config = config.get("completion_prompt_config", {})
            if completion_prompt_config:
                completion_prompt_template_params = {
                    'prompt': completion_prompt_config['prompt']['text'],
                    "prompt": completion_prompt_config["prompt"]["text"],
                }

                if 'conversation_histories_role' in completion_prompt_config:
                    completion_prompt_template_params['role_prefix'] = {
                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
                if "conversation_histories_role" in completion_prompt_config:
                    completion_prompt_template_params["role_prefix"] = {
                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
                    }

                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(

@@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
            return PromptTemplateEntity(
                prompt_type=prompt_type,
                advanced_chat_prompt_template=advanced_chat_prompt_template,
                advanced_completion_prompt_template=advanced_completion_prompt_template
                advanced_completion_prompt_template=advanced_completion_prompt_template,
            )

    @classmethod

@@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
            config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value

        prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
        if config['prompt_type'] not in prompt_type_vals:
        if config["prompt_type"] not in prompt_type_vals:
            raise ValueError(f"prompt_type must be in {prompt_type_vals}")

        # chat_prompt_config

@@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
        if not isinstance(config["completion_prompt_config"], dict):
            raise ValueError("completion_prompt_config must be of object type")

        if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
            if not config['chat_prompt_config'] and not config['completion_prompt_config']:
                raise ValueError("chat_prompt_config or completion_prompt_config is required "
                                 "when prompt_type is advanced")
        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
                raise ValueError(
                    "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
                )

            model_mode_vals = [mode.value for mode in ModelMode]
            if config['model']["mode"] not in model_mode_vals:
            if config["model"]["mode"] not in model_mode_vals:
                raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")

            if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
                user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
                assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]

                if not user_prefix:
                    config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"

                if not assistant_prefix:
                    config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"

            if config['model']["mode"] == ModelMode.CHAT.value:
                prompt_list = config['chat_prompt_config']['prompt']
            if config["model"]["mode"] == ModelMode.CHAT.value:
                prompt_list = config["chat_prompt_config"]["prompt"]

                if len(prompt_list) > 10:
                    raise ValueError("prompt messages must be less than 10")
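Note that the reformatted `raise ValueError(...)` above still contains two adjacent string literals. The formatter re-wraps implicitly concatenated strings but does not merge them; that cleanup is left to lint rules (the ISC rules from the flake8-implicit-str-concat family, if enabled). At runtime the adjacent literals simply concatenate:

msg = "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
assert msg == "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"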
@@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
        variable_entities = []

        # old external_data_tools
        external_data_tools = config.get('external_data_tools', [])
        external_data_tools = config.get("external_data_tools", [])
        for external_data_tool in external_data_tools:
            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                continue

            external_data_variables.append(
                ExternalDataVariableEntity(
                    variable=external_data_tool['variable'],
                    type=external_data_tool['type'],
                    config=external_data_tool['config']
                    variable=external_data_tool["variable"],
                    type=external_data_tool["type"],
                    config=external_data_tool["config"],
                )
            )

        # variables and external_data_tools
        for variables in config.get('user_input_form', []):
        for variables in config.get("user_input_form", []):
            variable_type = list(variables.keys())[0]
            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                variable = variables[variable_type]
                if 'config' not in variable:
                if "config" not in variable:
                    continue

                external_data_variables.append(
                    ExternalDataVariableEntity(
                        variable=variable['variable'],
                        type=variable['type'],
                        config=variable['config']
                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                    )
                )
            elif variable_type in [

@@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
                variable_entities.append(
                    VariableEntity(
                        type=variable_type,
                        variable=variable.get('variable'),
                        description=variable.get('description'),
                        label=variable.get('label'),
                        required=variable.get('required', False),
                        max_length=variable.get('max_length'),
                        options=variable.get('options'),
                        default=variable.get('default'),
                        variable=variable.get("variable"),
                        description=variable.get("description"),
                        label=variable.get("label"),
                        required=variable.get("required", False),
                        max_length=variable.get("max_length"),
                        options=variable.get("options"),
                        default=variable.get("default"),
                    )
                )

@@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")

            form_item = item[key]
            if 'label' not in form_item:
            if "label" not in form_item:
                raise ValueError("label is required in user_input_form")

            if not isinstance(form_item["label"], str):
                raise ValueError("label in user_input_form must be of string type")

            if 'variable' not in form_item:
            if "variable" not in form_item:
                raise ValueError("variable is required in user_input_form")

            if not isinstance(form_item["variable"], str):

@@ -117,26 +115,24 @@ class BasicVariablesConfigManager:

            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
            if pattern.match(form_item["variable"]) is None:
                raise ValueError("variable in user_input_form must be a string, "
                                 "and cannot start with a number")
                raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")

            variables.append(form_item["variable"])

            if 'required' not in form_item or not form_item["required"]:
            if "required" not in form_item or not form_item["required"]:
                form_item["required"] = False

            if not isinstance(form_item["required"], bool):
                raise ValueError("required in user_input_form must be of boolean type")

            if key == "select":
                if 'options' not in form_item or not form_item["options"]:
                if "options" not in form_item or not form_item["options"]:
                    form_item["options"] = []

                if not isinstance(form_item["options"], list):
                    raise ValueError("options in user_input_form must be a list of strings")

                if "default" in form_item and form_item['default'] \
                        and form_item["default"] not in form_item["options"]:
                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                    raise ValueError("default value in user_input_form must be in the options list")

        return config, ["user_input_form"]

@@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
            typ = tool["type"]
            config = tool["config"]

            ExternalDataToolFactory.validate_config(
                name=typ,
                tenant_id=tenant_id,
                config=config
            )
            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)

        return config, ["external_data_tools"]
@@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
    """
    Model Config Entity.
    """

    provider: str
    model: str
    mode: Optional[str] = None

@@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
    """
    Advanced Chat Message Entity.
    """

    text: str
    role: PromptMessageRole

@@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
    """
    Advanced Chat Prompt Template Entity.
    """

    messages: list[AdvancedChatMessageEntity]

@@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
    """
    Role Prefix Entity.
    """

    user: str
    assistant: str

@@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
        Prompt Type.
        'simple', 'advanced'
        """
        SIMPLE = 'simple'
        ADVANCED = 'advanced'

        SIMPLE = "simple"
        ADVANCED = "advanced"

        @classmethod
        def value_of(cls, value: str) -> 'PromptType':
        def value_of(cls, value: str) -> "PromptType":
            """
            Get value of given mode.

@@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f'invalid prompt type value {value}')
            raise ValueError(f"invalid prompt type value {value}")

    prompt_type: PromptType
    simple_prompt_template: Optional[str] = None

@@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
    """
    External Data Variable Entity.
    """

    variable: str
    type: str
    config: dict[str, Any] = {}

@@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
        Dataset Retrieve Strategy.
        'single' or 'multiple'
        """
        SINGLE = 'single'
        MULTIPLE = 'multiple'

        SINGLE = "single"
        MULTIPLE = "multiple"

        @classmethod
        def value_of(cls, value: str) -> 'RetrieveStrategy':
        def value_of(cls, value: str) -> "RetrieveStrategy":
            """
            Get value of given mode.

@@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f'invalid retrieve strategy value {value}')
            raise ValueError(f"invalid retrieve strategy value {value}")

    query_variable: Optional[str] = None  # Only when app mode is completion

    retrieve_strategy: RetrieveStrategy
    top_k: Optional[int] = None
    score_threshold: Optional[float] = .0
    rerank_mode: Optional[str] = 'reranking_model'
    score_threshold: Optional[float] = 0.0
    rerank_mode: Optional[str] = "reranking_model"
    reranking_model: Optional[dict] = None
    weights: Optional[dict] = None
    reranking_enabled: Optional[bool] = True


class DatasetEntity(BaseModel):
    """
    Dataset Config Entity.
    """

    dataset_ids: list[str]
    retrieve_config: DatasetRetrieveConfigEntity

@@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
    """
    Sensitive Word Avoidance Entity.
    """

    type: str
    config: dict[str, Any] = {}

@@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
    """
    Sensitive Word Avoidance Entity.
    """

    enabled: bool
    voice: Optional[str] = None
    language: Optional[str] = None

@@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
    """
    Tracing Config Entity.
    """

    enabled: bool
    tracing_provider: str


class AppAdditionalFeatures(BaseModel):
    file_upload: Optional[FileExtraConfig] = None
    opening_statement: Optional[str] = None

@@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
    text_to_speech: Optional[TextToSpeechEntity] = None
    trace_config: Optional[TracingConfigEntity] = None


class AppConfig(BaseModel):
    """
    Application Config Entity.
    """

    tenant_id: str
    app_id: str
    app_mode: AppMode

@@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
    """
    App Model Config From.
    """
    ARGS = 'args'
    APP_LATEST_CONFIG = 'app-latest-config'
    CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'

    ARGS = "args"
    APP_LATEST_CONFIG = "app-latest-config"
    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"


class EasyUIBasedAppConfig(AppConfig):
    """
    Easy UI Based App Config Entity.
    """

    app_model_config_from: EasyUIBasedAppModelConfigFrom
    app_model_config_id: str
    app_model_config_dict: dict

@@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
    """
    Workflow UI Based App Config Entity.
    """

    workflow_id: str
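Two smaller normalizations run through the entity hunks: the formatter inserts a blank line between each class docstring and the first field (hence the +1 in the hunk line counts), and bare float literals are completed, so `.0` becomes `0.0`. A sketch of the resulting layout (requires pydantic; the class is an illustrative stand-in, not the real entity):

from typing import Optional

from pydantic import BaseModel


class RetrieveConfigDemo(BaseModel):
    """
    Demo entity showing the formatted layout.
    """

    # The blank line above is inserted by the formatter; `.0` is rewritten as `0.0`.
    score_threshold: Optional[float] = 0.0


assert RetrieveConfigDemo().score_threshold == 0.0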
@@ -13,21 +13,19 @@ class FileUploadConfigManager:
        :param config: model config args
        :param is_vision: if True, the feature is vision feature
        """
        file_upload_dict = config.get('file_upload')
        file_upload_dict = config.get("file_upload")
        if file_upload_dict:
            if file_upload_dict.get('image'):
                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
            if file_upload_dict.get("image"):
                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
                    image_config = {
                        'number_limits': file_upload_dict['image']['number_limits'],
                        'transfer_methods': file_upload_dict['image']['transfer_methods']
                        "number_limits": file_upload_dict["image"]["number_limits"],
                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
                    }

                    if is_vision:
                        image_config['detail'] = file_upload_dict['image']['detail']
                        image_config["detail"] = file_upload_dict["image"]["detail"]

                    return FileExtraConfig(
                        image_config=image_config
                    )
                    return FileExtraConfig(image_config=image_config)

        return None

@@ -49,21 +47,21 @@ class FileUploadConfigManager:
        if not config["file_upload"].get("image"):
            config["file_upload"]["image"] = {"enabled": False}

        if config['file_upload']['image']['enabled']:
            number_limits = config['file_upload']['image']['number_limits']
        if config["file_upload"]["image"]["enabled"]:
            number_limits = config["file_upload"]["image"]["number_limits"]
            if number_limits < 1 or number_limits > 6:
                raise ValueError("number_limits must be in [1, 6]")

            if is_vision:
                detail = config['file_upload']['image']['detail']
                if detail not in ['high', 'low']:
                detail = config["file_upload"]["image"]["detail"]
                if detail not in ["high", "low"]:
                    raise ValueError("detail must be in ['high', 'low']")

            transfer_methods = config['file_upload']['image']['transfer_methods']
            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
            if not isinstance(transfer_methods, list):
                raise ValueError("transfer_methods must be of list type")
            for method in transfer_methods:
                if method not in ['remote_url', 'local_file']:
                if method not in ["remote_url", "local_file"]:
                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")

        return config, ["file_upload"]
@@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
        :param config: model config args
        """
        more_like_this = False
        more_like_this_dict = config.get('more_like_this')
        more_like_this_dict = config.get("more_like_this")
        if more_like_this_dict:
            if more_like_this_dict.get('enabled'):
            if more_like_this_dict.get("enabled"):
                more_like_this = True

        return more_like_this

@@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
        :param config: app model config args
        """
        if not config.get("more_like_this"):
            config["more_like_this"] = {
                "enabled": False
            }
            config["more_like_this"] = {"enabled": False}

        if not isinstance(config["more_like_this"], dict):
            raise ValueError("more_like_this must be of dict type")
@@ -1,5 +1,3 @@


class OpeningStatementConfigManager:
    @classmethod
    def convert(cls, config: dict) -> tuple[str, list]:

@@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
        :param config: model config args
        """
        # opening statement
        opening_statement = config.get('opening_statement')
        opening_statement = config.get("opening_statement")

        # suggested questions
        suggested_questions_list = config.get('suggested_questions')
        suggested_questions_list = config.get("suggested_questions")

        return opening_statement, suggested_questions_list
@@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
    @classmethod
    def convert(cls, config: dict) -> bool:
        show_retrieve_source = False
        retriever_resource_dict = config.get('retriever_resource')
        retriever_resource_dict = config.get("retriever_resource")
        if retriever_resource_dict:
            if retriever_resource_dict.get('enabled'):
            if retriever_resource_dict.get("enabled"):
                show_retrieve_source = True

        return show_retrieve_source

@@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
        :param config: app model config args
        """
        if not config.get("retriever_resource"):
            config["retriever_resource"] = {
                "enabled": False
            }
            config["retriever_resource"] = {"enabled": False}

        if not isinstance(config["retriever_resource"], dict):
            raise ValueError("retriever_resource must be of dict type")
@@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
        :param config: model config args
        """
        speech_to_text = False
        speech_to_text_dict = config.get('speech_to_text')
        speech_to_text_dict = config.get("speech_to_text")
        if speech_to_text_dict:
            if speech_to_text_dict.get('enabled'):
            if speech_to_text_dict.get("enabled"):
                speech_to_text = True

        return speech_to_text

@@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
        :param config: app model config args
        """
        if not config.get("speech_to_text"):
            config["speech_to_text"] = {
                "enabled": False
            }
            config["speech_to_text"] = {"enabled": False}

        if not isinstance(config["speech_to_text"], dict):
            raise ValueError("speech_to_text must be of dict type")
@@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        :param config: model config args
        """
        suggested_questions_after_answer = False
        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
        if suggested_questions_after_answer_dict:
            if suggested_questions_after_answer_dict.get('enabled'):
            if suggested_questions_after_answer_dict.get("enabled"):
                suggested_questions_after_answer = True

        return suggested_questions_after_answer

@@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        :param config: app model config args
        """
        if not config.get("suggested_questions_after_answer"):
            config["suggested_questions_after_answer"] = {
                "enabled": False
            }
            config["suggested_questions_after_answer"] = {"enabled": False}

        if not isinstance(config["suggested_questions_after_answer"], dict):
            raise ValueError("suggested_questions_after_answer must be of dict type")

        if "enabled" not in config["suggested_questions_after_answer"] or not \
                config["suggested_questions_after_answer"]["enabled"]:
        if (
            "enabled" not in config["suggested_questions_after_answer"]
            or not config["suggested_questions_after_answer"]["enabled"]
        ):
            config["suggested_questions_after_answer"]["enabled"] = False

        if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
@@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
        :param config: model config args
        """
        text_to_speech = None
        text_to_speech_dict = config.get('text_to_speech')
        text_to_speech_dict = config.get("text_to_speech")
        if text_to_speech_dict:
            if text_to_speech_dict.get('enabled'):
            if text_to_speech_dict.get("enabled"):
                text_to_speech = TextToSpeechEntity(
                    enabled=text_to_speech_dict.get('enabled'),
                    voice=text_to_speech_dict.get('voice'),
                    language=text_to_speech_dict.get('language'),
                    enabled=text_to_speech_dict.get("enabled"),
                    voice=text_to_speech_dict.get("voice"),
                    language=text_to_speech_dict.get("language"),
                )

        return text_to_speech

@@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
        :param config: app model config args
        """
        if not config.get("text_to_speech"):
            config["text_to_speech"] = {
                "enabled": False,
                "voice": "",
                "language": ""
            }
            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}

        if not isinstance(config["text_to_speech"], dict):
            raise ValueError("text_to_speech must be of dict type")
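All of the small feature managers above get the same treatment: a multi-line default dict with no trailing comma is collapsed to a single line once it fits. A compact before/after sketch:

config: dict = {}

# Before: the default spanned five lines.
# config["text_to_speech"] = {
#     "enabled": False,
#     "voice": "",
#     "language": ""
# }

# After: no trailing comma after "language" allowed the collapse; a trailing
# comma would have kept the dict exploded (magic trailing comma).
if not config.get("text_to_speech"):
    config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}

assert config["text_to_speech"] == {"enabled": False, "voice": "", "language": ""}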
@@ -1,4 +1,3 @@

from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig

@@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
    """
    Advanced Chatbot App Config Entity.
    """

    pass


class AdvancedChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def get_app_config(cls, app_model: App,
                       workflow: Workflow) -> AdvancedChatAppConfig:
    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
        features_dict = workflow.features_dict

        app_mode = AppMode.value_of(app_model.mode)

@@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
            app_id=app_model.id,
            app_mode=app_mode,
            workflow_id=workflow.id,
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
                config=features_dict
            ),
            variables=WorkflowVariablesConfigManager.convert(
                workflow=workflow
            ),
            additional_features=cls.convert_features(features_dict, app_mode)
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
            additional_features=cls.convert_features(features_dict, app_mode),
        )

        return app_config

@@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
            config=config,
            is_vision=False
            config=config, is_vision=False
        )
        related_config_keys.extend(current_related_config_keys)

@@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

        # suggested_questions_after_answer
        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
            config)
            config
        )
        related_config_keys.extend(current_related_config_keys)

        # speech_to_text

@@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id,
            config=config,
            only_structure_validate=only_structure_validate
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

@@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
    @overload
    def generate(
        self, app_model: App,
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,

@@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

    @overload
    def generate(
        self, app_model: App,
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,

@@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
    ) -> dict: ...

    def generate(
            self,
            app_model: App,
            workflow: Workflow,
            user: Union[Account, EndUser],
            args: dict,
            invoke_from: InvokeFrom,
            stream: bool = True,
    ) -> dict[str, Any] | Generator[str, Any, None]:
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: bool = True,
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param invoke_from: invoke from source
        :param stream: is stream
        """
        if not args.get('query'):
            raise ValueError('query is required')
        if not args.get("query"):
            raise ValueError("query is required")

        query = args['query']
        query = args["query"]
        if not isinstance(query, str):
            raise ValueError('query must be a string')
            raise ValueError("query must be a string")

        query = query.replace('\x00', '')
        inputs = args['inputs']
        query = query.replace("\x00", "")
        inputs = args["inputs"]

        extras = {
            "auto_generate_conversation_name": args.get('auto_generate_name', False)
        }
        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}

        # get conversation
        conversation = None
        conversation_id = args.get('conversation_id')
        conversation_id = args.get("conversation_id")
        if conversation_id:
            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
            conversation = self._get_conversation_by_user(
                app_model=app_model, conversation_id=conversation_id, user=user
            )

        # parse files
        files = args['files'] if args.get('files') else []
        files = args["files"] if args.get("files") else []
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        if file_extra_config:
            file_objs = message_file_parser.validate_and_transform_files_arg(
                files,
                file_extra_config,
                user
            )
            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
        else:
            file_objs = []

        # convert to app config
        app_config = AdvancedChatAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow
        )
        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # get tracing instance
        user_id = user.id if isinstance(user, Account) else user.session_id

@@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            stream=stream,
            invoke_from=invoke_from,
            extras=extras,
            trace_manager=trace_manager
            trace_manager=trace_manager,
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=invoke_from,
            application_generate_entity=application_generate_entity,
            conversation=conversation,
            stream=stream
            stream=stream,
        )

    def single_iteration_generate(self, app_model: App,
                                  workflow: Workflow,
                                  node_id: str,
                                  user: Account,
                                  args: dict,
                                  stream: bool = True) \
            -> dict[str, Any] | Generator[str, Any, None]:
    def single_iteration_generate(
        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param stream: is stream
        """
        if not node_id:
            raise ValueError('node_id is required')
            raise ValueError("node_id is required")

        if args.get('inputs') is None:
            raise ValueError('inputs is required')
        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = AdvancedChatAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow
        )
        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity
        application_generate_entity = AdvancedChatAppGenerateEntity(

@@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            app_config=app_config,
            conversation_id=None,
            inputs={},
            query='',
            query="",
            files=[],
            user_id=user.id,
            stream=stream,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={
                "auto_generate_conversation_name": False
            },
            extras={"auto_generate_conversation_name": False},
            single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id,
                inputs=args['inputs']
            )
                node_id=node_id, inputs=args["inputs"]
            ),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            conversation=None,
            stream=stream
            stream=stream,
        )

    def _generate(self, *,
                  workflow: Workflow,
                  user: Union[Account, EndUser],
                  invoke_from: InvokeFrom,
                  application_generate_entity: AdvancedChatAppGenerateEntity,
                  conversation: Optional[Conversation] = None,
                  stream: bool = True) \
            -> dict[str, Any] | Generator[str, Any, None]:
    def _generate(
        self,
        *,
        workflow: Workflow,
        user: Union[Account, EndUser],
        invoke_from: InvokeFrom,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        is_first_conversation = True

        # init generate records
        (
            conversation,
            message
        ) = self._init_generate_records(application_generate_entity, conversation)
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        if is_first_conversation:
            # update conversation features

@@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id
            message_id=message.id,
        )

        # new thread
        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),  # type: ignore
            'application_generate_entity': application_generate_entity,
            'queue_manager': queue_manager,
            'conversation_id': conversation.id,
            'message_id': message.id,
            'context': contextvars.copy_context(),
        })
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
                "context": contextvars.copy_context(),
            },
        )

        worker_thread.start()

@@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            stream=stream,
        )

        return AdvancedChatAppGenerateResponseConverter.convert(
            response=response,
            invoke_from=invoke_from
        )
        return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(self, flask_app: Flask,
                         application_generate_entity: AdvancedChatAppGenerateEntity,
                         queue_manager: AppQueueManager,
                         conversation_id: str,
                         message_id: str,
                         context: contextvars.Context) -> None:
    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,
        message_id: str,
        context: contextvars.Context,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app

@@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message
                    message=message,
                )

                runner.run()

@@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError('Incorrect API key provided'),
                    PublishFrom.APPLICATION_MANAGER
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except (ValueError, InvokeError) as e:
                if os.environ.get("DEBUG", "false").lower() == 'true':
                if os.environ.get("DEBUG", "false").lower() == "true":
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
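The generator hunks also show how multi-line def signatures and nested threading.Thread(...) calls come out of the formatter: every parameter moves to its own indented line with a trailing comma, and the backslash-continued return annotation becomes a plain `) -> ...:` on the closing line. A self-contained sketch of the resulting call style (the worker is a hypothetical stand-in for `_generate_worker`):

import threading


def _worker(*, flask_app: object, message_id: str) -> None:
    # Hypothetical stand-in for the real generate worker.
    print(f"processing {message_id}")


# Formatted style: target and kwargs on their own lines, one dict entry per
# line, trailing commas throughout.
worker_thread = threading.Thread(
    target=_worker,
    kwargs={
        "flask_app": object(),
        "message_id": "msg-1",
    },
)
worker_thread.start()
worker_thread.join()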
@@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
        content_text=text_content.strip(),
        user="responding_tts",
        tenant_id=tenant_id,
        voice=voice
        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
    )


@@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
        except Exception as e:
            logging.getLogger(__name__).warning(e)
            break
    audio_queue.put(AudioTrunk("finish", b''))
    audio_queue.put(AudioTrunk("finish", b""))


class AppGeneratorTTSPublisher:

    def __init__(self, tenant_id: str, voice: str):
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ''
        self.msg_text = ""
        self._audio_queue = queue.Queue()
        self._msg_queue = queue.Queue()
        self.match = re.compile(r'[。.!?]')
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.TTS
            tenant_id=self.tenant_id, model_type=ModelType.TTS
        )
        self.voices = self.model_instance.get_tts_voices()
        values = [voice.get('value') for voice in self.voices]
        values = [voice.get("value") for voice in self.voices]
        self.voice = voice
        if not voice or voice not in values:
            self.voice = self.voices[0].get('value')
            self.voice = self.voices[0].get("value")
        self.MAX_SENTENCE = 2
        self._last_audio_event = None
        self._runtime_thread = threading.Thread(target=self._runtime).start()

@@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
                message = self._msg_queue.get()
                if message is None:
                    if self.msg_text and len(self.msg_text.strip()) > 0:
                        futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
                                                              self.model_instance, self.tenant_id, self.voice)
                        futures_result = self.executor.submit(
                            _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
                        )
                        future_queue.put(futures_result)
                    break
                elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):

@@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
                elif isinstance(message.event, QueueTextChunkEvent):
                    self.msg_text += message.event.text
                elif isinstance(message.event, QueueNodeSucceededEvent):
                    self.msg_text += message.event.outputs.get('output', '')
                    self.msg_text += message.event.outputs.get("output", "")
                self.last_message = message
                sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                    self.MAX_SENTENCE += 1
                    text_content = ''.join(sentence_arr)
                    futures_result = self.executor.submit(_invoiceTTS, text_content,
                                                          self.model_instance,
                                                          self.tenant_id,
                                                          self.voice)
                    text_content = "".join(sentence_arr)
                    futures_result = self.executor.submit(
                        _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
                    )
                    future_queue.put(futures_result)
                    if text_tmp:
                        self.msg_text = text_tmp
                    else:
                        self.msg_text = ''
                        self.msg_text = ""

            except Exception as e:
                self.logger.warning(e)
@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
- self,
- application_generate_entity: AdvancedChatAppGenerateEntity,
- queue_manager: AppQueueManager,
- conversation: Conversation,
- message: Message
+ self,
+ application_generate_entity: AdvancedChatAppGenerateEntity,
+ queue_manager: AppQueueManager,
+ conversation: Conversation,
+ message: Message,
+ ) -> None:
"""
:param application_generate_entity: application generate entity

@@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
- raise ValueError('App not found')
+ raise ValueError("App not found")

workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
- raise ValueError('Workflow not initialized')
+ raise ValueError("Workflow not initialized")

user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:

@@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id

workflow_callbacks: list[WorkflowCallback] = []
- if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+ if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())

if self.application_generate_entity.single_iteration_run:

@@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
- user_inputs=self.application_generate_entity.single_iteration_run.inputs
+ user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs

@@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

# moderation
if self.handle_input_moderation(
- app_record=app_record,
- app_generate_entity=self.application_generate_entity,
- inputs=inputs,
- query=query,
- message_id=self.message.id
+ app_record=app_record,
+ app_generate_entity=self.application_generate_entity,
+ inputs=inputs,
+ query=query,
+ message_id=self.message.id,
):
return

# annotation reply
if self.handle_annotation_reply(
- app_record=app_record,
- message=self.message,
- query=query,
- app_generate_entity=self.application_generate_entity
+ app_record=app_record,
+ message=self.message,
+ query=query,
+ app_generate_entity=self.application_generate_entity,
):
return

# Init conversation variables
stmt = select(ConversationVariable).where(
- ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
+ ConversationVariable.app_id == self.conversation.app_id,
+ ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()

@@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._handle_event(workflow_entry, event)

def handle_input_moderation(
- self,
- app_record: App,
- app_generate_entity: AdvancedChatAppGenerateEntity,
- inputs: Mapping[str, Any],
- query: str,
- message_id: str
+ self,
+ app_record: App,
+ app_generate_entity: AdvancedChatAppGenerateEntity,
+ inputs: Mapping[str, Any],
+ query: str,
+ message_id: str,
) -> bool:
"""
Handle input moderation

@@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
)
except ModerationException as e:
- self._complete_with_stream_output(
- text=str(e),
- stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
- )
+ self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True

return False

- def handle_annotation_reply(self, app_record: App,
- message: Message,
- query: str,
- app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
+ def handle_annotation_reply(
+ self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
+ ) -> bool:
"""
Handle annotation reply
:param app_record: app record

@@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)

if annotation_reply:
- self._publish_event(
- QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
- )
+ self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))

self._complete_with_stream_output(
- text=annotation_reply.content,
- stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
+ text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True

return False

- def _complete_with_stream_output(self,
- text: str,
- stopped_by: QueueStopEvent.StopBy) -> None:
+ def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
- self._publish_event(
- QueueTextChunkEvent(
- text=text
- )
- )
+ self._publish_event(QueueTextChunkEvent(text=text))

- self._publish_event(
- QueueStopEvent(stopped_by=stopped_by)
- )
+ self._publish_event(QueueStopEvent(stopped_by=stopped_by))
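The reformatted _complete_with_stream_output above publishes a final text chunk and then a stop event. A minimal runnable sketch of that two-event direct-output pattern, with stand-in event classes and a plain queue instead of dify's queue manager:

# Sketch of the direct-output pattern; the event classes and queue below are
# stand-ins, not dify's real QueueTextChunkEvent/QueueStopEvent or publisher.
from dataclasses import dataclass
from queue import Queue


@dataclass
class TextChunkEvent:
    text: str


@dataclass
class StopEvent:
    stopped_by: str  # e.g. "input-moderation" or "annotation-reply"


def complete_with_stream_output(q: Queue, text: str, stopped_by: str) -> None:
    # First flush the final text, then signal consumers to stop listening.
    q.put(TextChunkEvent(text=text))
    q.put(StopEvent(stopped_by=stopped_by))


q: Queue = Queue()
complete_with_stream_output(q, "This content violates our usage policy.", "input-moderation")
while not q.empty():
    print(q.get())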
@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
- 'event': 'message',
- 'task_id': blocking_response.task_id,
- 'id': blocking_response.data.id,
- 'message_id': blocking_response.data.message_id,
- 'conversation_id': blocking_response.data.conversation_id,
- 'mode': blocking_response.data.mode,
- 'answer': blocking_response.data.answer,
- 'metadata': blocking_response.data.metadata,
- 'created_at': blocking_response.data.created_at
+ "event": "message",
+ "task_id": blocking_response.task_id,
+ "id": blocking_response.data.id,
+ "message_id": blocking_response.data.message_id,
+ "conversation_id": blocking_response.data.conversation_id,
+ "mode": blocking_response.data.mode,
+ "answer": blocking_response.data.answer,
+ "metadata": blocking_response.data.metadata,
+ "created_at": blocking_response.data.created_at,
}

return response

@@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)

- metadata = response.get('metadata', {})
- response['metadata'] = cls._get_simple_metadata(metadata)
+ metadata = response.get("metadata", {})
+ response["metadata"] = cls._get_simple_metadata(metadata)

return response

@classmethod
- def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+ def convert_stream_full_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response

@@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
- yield 'ping'
+ yield "ping"
continue

response_chunk = {
- 'event': sub_stream_response.event.value,
- 'conversation_id': chunk.conversation_id,
- 'message_id': chunk.message_id,
- 'created_at': chunk.created_at
+ "event": sub_stream_response.event.value,
+ "conversation_id": chunk.conversation_id,
+ "message_id": chunk.message_id,
+ "created_at": chunk.created_at,
}

if isinstance(sub_stream_response, ErrorStreamResponse):

@@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)

@classmethod
- def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+ def convert_stream_simple_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
- yield 'ping'
+ yield "ping"
continue

response_chunk = {
- 'event': sub_stream_response.event.value,
- 'conversation_id': chunk.conversation_id,
- 'message_id': chunk.message_id,
- 'created_at': chunk.created_at
+ "event": sub_stream_response.event.value,
+ "conversation_id": chunk.conversation_id,
+ "message_id": chunk.message_id,
+ "created_at": chunk.created_at,
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
- metadata = sub_stream_response_dict.get('metadata', {})
- sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+ metadata = sub_stream_response_dict.get("metadata", {})
+ sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
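Each converter method above follows the same envelope-then-serialize shape: build a chunk dict with the event name and identifiers, merge in the sub-response payload, and json.dumps the result. A compact sketch, with identifiers and payload invented for illustration:

# Envelope-then-serialize sketch; "conv-1"/"msg-1" and the payload are made up.
import json
import time


def to_stream_chunk(event: str, conversation_id: str, message_id: str, payload: dict) -> str:
    response_chunk = {
        "event": event,
        "conversation_id": conversation_id,
        "message_id": message_id,
        "created_at": int(time.time()),
    }
    # Sub-response fields are merged into the envelope before serializing.
    response_chunk.update(payload)
    return json.dumps(response_chunk)


print(to_stream_chunk("message", "conv-1", "msg-1", {"answer": "hello"}))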
@@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow

@@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariableKey, Any]

def __init__(
- self,
- application_generate_entity: AdvancedChatAppGenerateEntity,
- workflow: Workflow,
- queue_manager: AppQueueManager,
- conversation: Conversation,
- message: Message,
- user: Union[Account, EndUser],
- stream: bool,
+ self,
+ application_generate_entity: AdvancedChatAppGenerateEntity,
+ workflow: Workflow,
+ queue_manager: AppQueueManager,
+ conversation: Conversation,
+ message: Message,
+ user: Union[Account, EndUser],
+ stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.

@@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
- self._conversation,
- self._application_generate_entity.query
+ self._conversation, self._application_generate_entity.query
)

- generator = self._wrapper_process_stream_response(
- trace_manager=self._application_generate_entity.trace_manager
- )
+ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

if self._stream:
return self._to_stream_response(generator)

@@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
- extras['metadata'] = stream_response.metadata
+ extras["metadata"] = stream_response.metadata

return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,

@@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
- **extras
- )
+ **extras,
+ ),
)
else:
continue

- raise Exception('Queue listening stopped unexpectedly.')
+ raise Exception("Queue listening stopped unexpectedly.")

- def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
+ def _to_stream_response(
+ self, generator: Generator[StreamResponse, None, None]
+ ) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:

@@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
- stream_response=stream_response
+ stream_response=stream_response,
)

def _listenAudioMsg(self, publisher, task_id: str):

@@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

- def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
- Generator[StreamResponse, None, None]:
+ def _wrapper_process_stream_response(
+ self, trace_manager: Optional[TraceQueueManager] = None
+ ) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict

- if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
- 'text_to_speech'].get('autoPlay') == 'enabled':
- tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+ if (
+ features_dict.get("text_to_speech")
+ and features_dict["text_to_speech"].get("enabled")
+ and features_dict["text_to_speech"].get("autoPlay") == "enabled"
+ ):
+ tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:

@@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
- yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+ yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

def _process_stream_response(
- self,
- tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
- trace_manager: Optional[TraceQueueManager] = None
+ self,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.

@@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()

yield self._workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

- workflow_node_execution = self._handle_node_execution_start(
- workflow_run=workflow_run,
- event=event
- )
+ workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution
+ workflow_node_execution=workflow_node_execution,
)

if response:

@@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution
+ workflow_node_execution=workflow_node_execution,
)

if response:

@@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution
+ workflow_node_execution=workflow_node_execution,
)

if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

yield self._workflow_parallel_branch_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run,
- event=event
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

yield self._workflow_parallel_branch_finished_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run,
- event=event
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run,
- event=event
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run,
- event=event
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run,
- event=event
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
- raise Exception('Graph runtime state not initialized.')
+ raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,

@@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)

yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

- self._queue_manager.publish(
- QueueAdvancedChatMessageEndEvent(),
- PublishFrom.TASK_PIPELINE
- )
+ self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
- raise Exception('Workflow run not initialized.')
+ raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
- raise Exception('Graph runtime state not initialized.')
+ raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,

@@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)

yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

- err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+ err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):

@@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)

yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_run=workflow_run
+ task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

# Save message

@@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

self._refetch_message()

- self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
- if self._task_state.metadata else None
+ self._message.message_metadata = (
+ json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+ )

db.session.commit()
db.session.refresh(self._message)

@@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

self._refetch_message()

- self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
- if self._task_state.metadata else None
+ self._message.message_metadata = (
+ json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+ )

db.session.commit()
db.session.refresh(self._message)

@@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
- raise Exception('Graph runtime state not initialized.')
+ raise Exception("Graph runtime state not initialized.")

output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:

@@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
- self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
- if self._task_state.metadata else None
+ self._message.message_metadata = (
+ json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+ )

if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage

@@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
- extras=self._application_generate_entity.extras
+ extras=self._application_generate_entity.extras,
)

def _message_end_to_stream_response(self) -> MessageEndStreamResponse:

@@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
- extras['metadata'] = self._task_state.metadata.copy()
+ extras["metadata"] = self._task_state.metadata.copy()

- if 'annotation_reply' in extras['metadata']:
- del extras['metadata']['annotation_reply']
+ if "annotation_reply" in extras["metadata"]:
+ del extras["metadata"]["annotation_reply"]

return MessageEndStreamResponse(
- task_id=self._application_generate_entity.task_id,
- id=self._message.id,
- **extras
+ task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)

def _handle_output_moderation_chunk(self, text: str) -> bool:

@@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
- QueueTextChunkEvent(
- text=self._task_state.answer
- ), PublishFrom.TASK_PIPELINE
+ QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)

self._queue_manager.publish(
- QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
- PublishFrom.TASK_PIPELINE
+ QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
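The text-to-speech gate above was rewritten into one parenthesized condition over nested .get() lookups, which stays safe when the feature block is missing. The same check, runnable against a hand-built features dict (the voice value here is invented):

# Nested feature-flag check mirroring the reformatted condition above;
# the dict contents are illustrative, not dify's actual feature schema.
features_dict = {"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "alloy"}}

if (
    features_dict.get("text_to_speech")
    and features_dict["text_to_speech"].get("enabled")
    and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
    # .get() returns None instead of raising when a key is absent.
    print("autoplay TTS with voice:", features_dict["text_to_speech"].get("voice"))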
@@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""

agent: Optional[AgentEntity] = None


class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
- def get_app_config(cls, app_model: App,
- app_model_config: AppModelConfig,
- conversation: Optional[Conversation] = None,
- override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
+ def get_app_config(
+ cls,
+ app_model: App,
+ app_model_config: AppModelConfig,
+ conversation: Optional[Conversation] = None,
+ override_config_dict: Optional[dict] = None,
+ ) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model

@@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
- model=ModelConfigManager.convert(
- config=config_dict
- ),
- prompt_template=PromptTemplateConfigManager.convert(
- config=config_dict
- ),
- sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
- config=config_dict
- ),
- dataset=DatasetConfigManager.convert(
- config=config_dict
- ),
- agent=AgentConfigManager.convert(
- config=config_dict
- ),
- additional_features=cls.convert_features(config_dict, app_mode)
+ model=ModelConfigManager.convert(config=config_dict),
+ prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
+ sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
+ dataset=DatasetConfigManager.convert(config=config_dict),
+ agent=AgentConfigManager.convert(config=config_dict),
+ additional_features=cls.convert_features(config_dict, app_mode),
)

app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(

@@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):

# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
- config)
+ config
+ )
related_config_keys.extend(current_related_config_keys)

# speech_to_text

@@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):

# dataset configs
# dataset_query_variable
- config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
- config)
+ config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
+ tenant_id, app_mode, config
+ )
related_config_keys.extend(current_related_config_keys)

# moderation validation
- config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
- config)
+ config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+ tenant_id, config
+ )
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))

@@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args
"""
if not config.get("agent_mode"):
- config["agent_mode"] = {
- "enabled": False,
- "tools": []
- }
+ config["agent_mode"] = {"enabled": False, "tools": []}

if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")

@@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

- if config["agent_mode"]["strategy"] not in [member.value for member in
- list(PlanningStrategy.__members__.values())]:
+ if config["agent_mode"]["strategy"] not in [
+ member.value for member in list(PlanningStrategy.__members__.values())
+ ]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

if not config["agent_mode"].get("tools"):

@@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type")

if key == "dataset":
- if 'id' not in tool_item:
+ if "id" not in tool_item:
raise ValueError("id is required in dataset")

try:
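The agent_mode validation above defaults the section and then checks the chosen strategy against the values of an enum. A self-contained sketch with PlanningStrategy stubbed to two members (ROUTER matches the diff; REACT is assumed for illustration):

# Stub of the agent_mode default-and-validate flow shown above.
from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = "router"
    REACT = "react"  # assumed member, for illustration only


def validate_agent_mode(config: dict) -> dict:
    # Default the whole section if it is missing or empty.
    if not config.get("agent_mode"):
        config["agent_mode"] = {"enabled": False, "tools": []}
    if not config["agent_mode"].get("strategy"):
        config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
    # Membership check against the enum's values, as in the diff.
    if config["agent_mode"]["strategy"] not in [
        member.value for member in list(PlanningStrategy.__members__.values())
    ]:
        raise ValueError("strategy in agent_mode must be in the specified strategy list")
    return config


print(validate_agent_mode({}))  # -> {'agent_mode': {'enabled': False, 'tools': [], 'strategy': 'router'}}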
@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
- self, app_model: App,
+ self,
+ app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,

@@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):

@overload
def generate(
- self, app_model: App,
+ self,
+ app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...

- def generate(self, app_model: App,
- user: Union[Account, EndUser],
- args: Any,
- invoke_from: InvokeFrom,
- stream: bool = True) \
- -> Union[dict, Generator[dict, None, None]]:
+ def generate(
+ self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+ ) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.

@@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not stream:
- raise ValueError('Agent Chat App does not support blocking mode')
+ raise ValueError("Agent Chat App does not support blocking mode")

- if not args.get('query'):
- raise ValueError('query is required')
+ if not args.get("query"):
+ raise ValueError("query is required")

- query = args['query']
+ query = args["query"]
if not isinstance(query, str):
- raise ValueError('query must be a string')
+ raise ValueError("query must be a string")

- query = query.replace('\x00', '')
- inputs = args['inputs']
+ query = query.replace("\x00", "")
+ inputs = args["inputs"]

- extras = {
- "auto_generate_conversation_name": args.get('auto_generate_name', True)
- }
+ extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

# get conversation
conversation = None
- if args.get('conversation_id'):
- conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+ if args.get("conversation_id"):
+ conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)

# get app model config
- app_model_config = self._get_app_model_config(
- app_model=app_model,
- conversation=conversation
- )
+ app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

# validate override model config
override_model_config_dict = None
- if args.get('model_config'):
+ if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
- raise ValueError('Only in App debug mode can override model config')
+ raise ValueError("Only in App debug mode can override model config")

# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
- tenant_id=app_model.tenant_id,
- config=args.get('model_config')
+ tenant_id=app_model.tenant_id, config=args.get("model_config")
)

# always enable retriever resource in debugger mode
- override_model_config_dict["retriever_resource"] = {
- "enabled": True
- }
+ override_model_config_dict["retriever_resource"] = {"enabled": True}

# parse files
- files = args['files'] if args.get('files') else []
+ files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
- file_objs = message_file_parser.validate_and_transform_files_arg(
- files,
- file_extra_config,
- user
- )
+ file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []

@@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
- override_config_dict=override_model_config_dict
+ override_config_dict=override_model_config_dict,
)

# get tracing instance

@@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
call_depth=0,
- trace_manager=trace_manager
+ trace_manager=trace_manager,
)

# init generate records
- (
- conversation,
- message
- ) = self._init_generate_records(application_generate_entity, conversation)
+ (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

# init queue manager
queue_manager = MessageBasedAppQueueManager(

@@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
- message_id=message.id
+ message_id=message.id,
)

# new thread
- worker_thread = threading.Thread(target=self._generate_worker, kwargs={
- 'flask_app': current_app._get_current_object(),
- 'application_generate_entity': application_generate_entity,
- 'queue_manager': queue_manager,
- 'conversation_id': conversation.id,
- 'message_id': message.id,
- })
+ worker_thread = threading.Thread(
+ target=self._generate_worker,
+ kwargs={
+ "flask_app": current_app._get_current_object(),
+ "application_generate_entity": application_generate_entity,
+ "queue_manager": queue_manager,
+ "conversation_id": conversation.id,
+ "message_id": message.id,
+ },
+ )

worker_thread.start()

@@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

- return AgentChatAppGenerateResponseConverter.convert(
- response=response,
- invoke_from=invoke_from
- )
+ return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def _generate_worker(
- self, flask_app: Flask,
+ self,
+ flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,

@@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
- InvokeAuthorizationError('Incorrect API key provided'),
- PublishFrom.APPLICATION_MANAGER
+ InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
- if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+ if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
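The generator above launches its worker with threading.Thread(target=..., kwargs={...}). A minimal runnable sketch of that launch pattern, with the Flask app and generate entity omitted and the worker body reduced to a stub:

# Background-worker launch sketch; the worker body and ids are stand-ins.
import threading
from queue import Queue


def generate_worker(*, queue_manager: Queue, conversation_id: str, message_id: str) -> None:
    # Real code would stream model output; here we just report completion.
    queue_manager.put(f"generated for {conversation_id}/{message_id}")


queue_manager: Queue = Queue()
worker_thread = threading.Thread(
    target=generate_worker,
    kwargs={
        "queue_manager": queue_manager,
        "conversation_id": "conv-1",
        "message_id": "msg-1",
    },
)
worker_thread.start()
worker_thread.join()
print(queue_manager.get())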
@@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
"""

def run(
- self, application_generate_entity: AgentChatAppGenerateEntity,
+ self,
+ application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,

@@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
- query=query
+ query=query,
)

memory = None

@@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
- model=application_generate_entity.model_conf.model
+ model=application_generate_entity.model_conf.model,
)

- memory = TokenBufferMemory(
- conversation=conversation,
- model_instance=model_instance
- )
+ memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)

@@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
- memory=memory
+ memory=memory,
)

# moderation

@@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
- message_id=message.id
+ message_id=message.id,
)
except ModerationException as e:
self.direct_output(

@@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
- stream=application_generate_entity.stream
+ stream=application_generate_entity.stream,
)
return

@@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
- invoke_from=application_generate_entity.invoke_from
+ invoke_from=application_generate_entity.invoke_from,
)

if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
- PublishFrom.APPLICATION_MANAGER
+ PublishFrom.APPLICATION_MANAGER,
)

self.direct_output(

@@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
- stream=application_generate_entity.stream
+ stream=application_generate_entity.stream,
)
return

@@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
- query=query
+ query=query,
)

# reorganize all inputs and template to prompt messages

@@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
- memory=memory
+ memory=memory,
)

# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
- prompt_messages=prompt_messages
+ prompt_messages=prompt_messages,
)

if hosting_moderation_result:

@@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent

# load tool variables
- tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
- user_id=application_generate_entity.user_id,
- tenant_id=app_config.tenant_id)
+ tool_conversation_variables = self._load_tool_variables(
+ conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
+ )

# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

@@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
# init model instance
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
- model=application_generate_entity.model_conf.model
+ model=application_generate_entity.model_conf.model,
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,

@@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
- model_instance=model_instance
+ model_instance=model_instance,
)

invoke_result = runner.run(

@@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
- agent=True
+ agent=True,
)

def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
- tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
- ToolConversationVariables.conversation_id == conversation_id,
- ToolConversationVariables.tenant_id == tenant_id
- ).first()
+ tool_variables: ToolConversationVariables = (
+ db.session.query(ToolConversationVariables)
+ .filter(
+ ToolConversationVariables.conversation_id == conversation_id,
+ ToolConversationVariables.tenant_id == tenant_id,
+ )
+ .first()
+ )

if tool_variables:
# save tool variables to session, so that we can update it later

@@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
- variables_str='[]',
+ variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()

return tool_variables

- def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
+ def _convert_db_variables_to_tool_variables(
+ self, db_variables: ToolConversationVariables
+ ) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
- return ToolRuntimeVariablePool(**{
- 'conversation_id': db_variables.conversation_id,
- 'user_id': db_variables.user_id,
- 'tenant_id': db_variables.tenant_id,
- 'pool': db_variables.variables
- })
+ return ToolRuntimeVariablePool(
+ **{
+ "conversation_id": db_variables.conversation_id,
+ "user_id": db_variables.user_id,
+ "tenant_id": db_variables.tenant_id,
+ "pool": db_variables.variables,
+ }
+ )

- def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
- message: Message) -> LLMUsage:
+ def _get_usage_of_all_agent_thoughts(
+ self, model_config: ModelConfigWithCredentialsEntity, message: Message
+ ) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
- agent_thoughts = (db.session.query(MessageAgentThought)
- .filter(MessageAgentThought.message_id == message.id).all())
+ agent_thoughts = (
+ db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
+ )

all_message_tokens = 0
all_answer_tokens = 0

@@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance)

return model_type_instance._calc_response_usage(
- model_config.model,
- model_config.credentials,
- all_message_tokens,
- all_answer_tokens
+ model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)
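_convert_db_variables_to_tool_variables above builds its pool with ToolRuntimeVariablePool(**{...}). Unpacking a dict with ** is equivalent to passing the same keys as keyword arguments, as this stand-in shows:

# VariablePool is a stand-in for dify's ToolRuntimeVariablePool; ids invented.
from dataclasses import dataclass


@dataclass
class VariablePool:
    conversation_id: str
    user_id: str
    tenant_id: str
    pool: list


kwargs = {"conversation_id": "conv-1", "user_id": "u-1", "tenant_id": "t-1", "pool": []}
# Dict unpacking and explicit keywords construct identical objects.
assert VariablePool(**kwargs) == VariablePool(
    conversation_id="conv-1", user_id="u-1", tenant_id="t-1", pool=[]
)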
@@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
- 'event': 'message',
- 'task_id': blocking_response.task_id,
- 'id': blocking_response.data.id,
- 'message_id': blocking_response.data.message_id,
- 'conversation_id': blocking_response.data.conversation_id,
- 'mode': blocking_response.data.mode,
- 'answer': blocking_response.data.answer,
- 'metadata': blocking_response.data.metadata,
- 'created_at': blocking_response.data.created_at
+ "event": "message",
+ "task_id": blocking_response.task_id,
+ "id": blocking_response.data.id,
+ "message_id": blocking_response.data.message_id,
+ "conversation_id": blocking_response.data.conversation_id,
+ "mode": blocking_response.data.mode,
+ "answer": blocking_response.data.answer,
+ "metadata": blocking_response.data.metadata,
+ "created_at": blocking_response.data.created_at,
}

return response

@@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)

- metadata = response.get('metadata', {})
- response['metadata'] = cls._get_simple_metadata(metadata)
+ metadata = response.get("metadata", {})
+ response["metadata"] = cls._get_simple_metadata(metadata)

return response

@classmethod
- def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
- -> Generator[str, None, None]:
+ def convert_stream_full_response(
+ cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ ) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response

@@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
- yield 'ping'
+ yield "ping"
continue

response_chunk = {
- 'event': sub_stream_response.event.value,
- 'conversation_id': chunk.conversation_id,
- 'message_id': chunk.message_id,
- 'created_at': chunk.created_at
+ "event": sub_stream_response.event.value,
+ "conversation_id": chunk.conversation_id,
+ "message_id": chunk.message_id,
+ "created_at": chunk.created_at,
}

if isinstance(sub_stream_response, ErrorStreamResponse):

@@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)

@classmethod
- def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
- -> Generator[str, None, None]:
+ def convert_stream_simple_response(
+ cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ ) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
- yield 'ping'
+ yield "ping"
continue

response_chunk = {
- 'event': sub_stream_response.event.value,
- 'conversation_id': chunk.conversation_id,
- 'message_id': chunk.message_id,
- 'created_at': chunk.created_at
+ "event": sub_stream_response.event.value,
+ "conversation_id": chunk.conversation_id,
+ "message_id": chunk.message_id,
+ "created_at": chunk.created_at,
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
- metadata = sub_stream_response_dict.get('metadata', {})
- sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+ metadata = sub_stream_response_dict.get("metadata", {})
+ sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
@@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]

@classmethod
- def convert(cls, response: Union[
- AppBlockingResponse,
- Generator[AppStreamResponse, Any, None]
- ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
+ def convert(
+ cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
+ ) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:

def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
- if chunk == 'ping':
- yield f'event: {chunk}\n\n'
+ if chunk == "ping":
+ yield f"event: {chunk}\n\n"
else:
- yield f'data: {chunk}\n\n'
+ yield f"data: {chunk}\n\n"

return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:

def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
- if chunk == 'ping':
- yield f'event: {chunk}\n\n'
+ if chunk == "ping":
+ yield f"event: {chunk}\n\n"
else:
- yield f'data: {chunk}\n\n'
+ yield f"data: {chunk}\n\n"

return _generate_simple_response()

@@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):

@classmethod
@abstractmethod
- def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
- -> Generator[str, None, None]:
+ def convert_stream_full_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[str, None, None]:
raise NotImplementedError

@classmethod
@abstractmethod
- def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
- -> Generator[str, None, None]:
+ def convert_stream_simple_response(
+ cls, stream_response: Generator[AppStreamResponse, None, None]
+ ) -> Generator[str, None, None]:
raise NotImplementedError

@classmethod

@@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
# show_retrieve_source
- if 'retriever_resources' in metadata:
- metadata['retriever_resources'] = []
- for resource in metadata['retriever_resources']:
- metadata['retriever_resources'].append({
- 'segment_id': resource['segment_id'],
- 'position': resource['position'],
- 'document_name': resource['document_name'],
- 'score': resource['score'],
- 'content': resource['content'],
- })
+ if "retriever_resources" in metadata:
+ metadata["retriever_resources"] = []
+ for resource in metadata["retriever_resources"]:
+ metadata["retriever_resources"].append(
+ {
+ "segment_id": resource["segment_id"],
+ "position": resource["position"],
+ "document_name": resource["document_name"],
+ "score": resource["score"],
+ "content": resource["content"],
+ }
+ )

# show annotation reply
- if 'annotation_reply' in metadata:
- del metadata['annotation_reply']
+ if "annotation_reply" in metadata:
+ del metadata["annotation_reply"]

# show usage
- if 'usage' in metadata:
- del metadata['usage']
+ if "usage" in metadata:
+ del metadata["usage"]

return metadata

@@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
error_responses = {
- ValueError: {'code': 'invalid_param', 'status': 400},
- ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
+ ValueError: {"code": "invalid_param", "status": 400},
+ ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
- 'code': 'provider_quota_exceeded',
- 'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
- "Please go to Settings -> Model Provider to complete your own provider credentials.",
- 'status': 400
+ "code": "provider_quota_exceeded",
+ "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
+ "Please go to Settings -> Model Provider to complete your own provider credentials.",
+ "status": 400,
},
- ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
- InvokeError: {'code': 'completion_request_error', 'status': 400}
+ ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
+ InvokeError: {"code": "completion_request_error", "status": 400},
}

# Determine the response based on the type of exception

@@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
data = v

if data:
- data.setdefault('message', getattr(e, 'description', str(e)))
+ data.setdefault("message", getattr(e, "description", str(e)))
else:
logging.error(e)
data = {
- 'code': 'internal_server_error',
- 'message': 'Internal Server Error, please contact support.',
- 'status': 500
+ "code": "internal_server_error",
+ "message": "Internal Server Error, please contact support.",
+ "status": 500,
}

return data
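_error_to_stream_response above resolves an exception to a payload by walking a type-to-response mapping, matching with isinstance, and filling a default message. A self-contained sketch with a stand-in exception class and a trimmed mapping:

# Trimmed version of the exception-to-payload mapping; QuotaExceededError here
# is a stand-in for dify's provider error of the same name.
class QuotaExceededError(Exception):
    pass


error_responses = {
    ValueError: {"code": "invalid_param", "status": 400},
    QuotaExceededError: {"code": "provider_quota_exceeded", "status": 400},
}


def error_to_response(e: Exception) -> dict:
    data = None
    # Match by exception type; isinstance also catches subclasses.
    for k, v in error_responses.items():
        if isinstance(e, k):
            data = dict(v)
    if data:
        data.setdefault("message", getattr(e, "description", str(e)))
    else:
        data = {
            "code": "internal_server_error",
            "message": "Internal Server Error, please contact support.",
            "status": 500,
        }
    return data


print(error_to_response(ValueError("query is required")))
print(error_to_response(RuntimeError("unexpected")))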
@@ -16,10 +16,10 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.variable} is required in input form')
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
return var.default or ""
if (
var.type
in (

@@ -34,7 +34,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
if "." in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)

@@ -43,14 +43,14 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

return user_input_value

def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace('\x00', '')
return value.replace("\x00", "")
return value
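
As context for the `_validate_input` hunks above, here is a condensed sketch of the validation rules they encode. The rule set is taken from the diff; the plain-string type tags stand in for the `VariableEntityType` enum and are assumptions:

# Condensed sketch of the per-variable validation above; `var` is assumed to
# expose .variable, .required, .default, .type, .options and .max_length.
def validate_input(inputs: dict, var) -> object:
    value = inputs.get(var.variable)
    if var.required and not value:
        raise ValueError(f"{var.variable} is required in input form")
    if not var.required and not value:
        return var.default or ""
    if var.type == "number" and isinstance(value, str):
        return float(value) if "." in value else int(value)  # may raise ValueError
    if var.type == "select" and value not in (var.options or []):
        raise ValueError(f"{var.variable} in input form must be one of the following: {var.options or []}")
    if var.type in ("text-input", "paragraph") and var.max_length and value and len(value) > var.max_length:
        raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
    return value
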
@@ -24,9 +24,7 @@ class PublishFrom(Enum):


class AppQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
if not user_id:
raise ValueError("user is required")

@@ -34,9 +32,10 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from

user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
f"{user_prefix}-{self._user_id}")
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)

q = queue.Queue()

@@ -66,8 +65,7 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
)

if elapsed_time // 10 > last_ping_time:

@@ -88,9 +86,7 @@ class AppQueueManager:
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(
error=e
), pub_from)
self.publish(QueueErrorEvent(error=e), pub_from)

def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""

@@ -122,8 +118,8 @@ class AppQueueManager:
if result is None:
return

user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return

stopped_cache_key = cls._generate_stopped_cache_key(task_id)

@@ -168,9 +164,11 @@ class AppQueueManager:
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)


class GenerateTaskStoppedException(Exception):
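
The `AppQueueManager` hunks above center on a task-ownership key in Redis: the manager records who started a task, and a stop request is honored only for the same owner. A minimal sketch of that pattern, assuming a standard `redis` client with `setex`/`get` and a key name that is illustrative only:

# Sketch of the task-ownership pattern from the hunks above.
def record_owner(redis_client, task_id: str, user_prefix: str, user_id: str) -> None:
    # 1800s TTL matches the diff; the key format is an assumption
    redis_client.setex(f"generate_task_belong:{task_id}", 1800, f"{user_prefix}-{user_id}")

def may_stop(redis_client, task_id: str, user_prefix: str, user_id: str) -> bool:
    # Only the recorded owner is allowed to set the stop flag
    result = redis_client.get(f"generate_task_belong:{task_id}")
    return result is not None and result.decode("utf-8") == f"{user_prefix}-{user_id}"
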
@@ -31,12 +31,15 @@ if TYPE_CHECKING:


class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None) -> int:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record

@@ -49,18 +52,20 @@ class AppRunner:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0

if model_context_tokens is None:
return -1

@@ -75,36 +80,39 @@ class AppRunner:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query
query=query,
)

prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)

return rest_tokens

def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0

if model_context_tokens is None:
return -1

@@ -112,27 +120,28 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0

prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)

for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens

def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def organize_prompt_messages(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:

@@ -152,60 +161,54 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
else:
memory_config = MemoryConfig(
window=MemoryConfig.WindowConfig(
enabled=False
)
)
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate(
text=advanced_completion_prompt_template.prompt
)
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)

if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(
text=message.text,
role=message.role
))
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))

prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
stop = model_config.stop

return prompt_messages, stop

def direct_output(self, queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
def direct_output(
self,
queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
"""
Direct output
:param queue_manager: application queue manager

@@ -222,17 +225,10 @@ class AppRunner:
chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
)

queue_manager.publish(
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)

@@ -242,15 +238,19 @@ class AppRunner:
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
usage=usage if usage else LLMUsage.empty_usage(),
),
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)

def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False) -> None:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result

@@ -260,21 +260,13 @@ class AppRunner:
:return:
"""
if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)

def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result direct
:param invoke_result: invoke result

@@ -285,12 +277,13 @@ class AppRunner:
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)

def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_stream(
self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result

@@ -300,21 +293,13 @@ class AppRunner:
"""
model = None
prompt_messages = []
text = ''
text = ""
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)

text += result.delta.message.content

@@ -331,25 +316,24 @@ class AppRunner:
usage = LLMUsage.empty_usage()

llm_result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
)

queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)

def moderation_for_inputs(
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.

@@ -367,14 +351,17 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query if query else '',
query=query if query else "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
trace_manager=app_generate_entity.trace_manager,
)

def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
def check_hosting_moderation(
self,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity

@@ -384,8 +371,7 @@ class AppRunner:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
)

if moderation_result:

@@ -393,18 +379,20 @@ class AppRunner:
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
)

return moderation_result

def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
def fill_in_inputs_from_external_data_tools(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
"""
Fill in variable inputs from external data tools if exists.

@@ -417,18 +405,12 @@ class AppRunner:
"""
external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
)

def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record

@@ -440,9 +422,5 @@ class AppRunner:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)
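
Both `get_pre_calculate_rest_tokens` and `recalc_llm_max_tokens` in the hunks above revolve around the same budget identity: rest_tokens = model_context_tokens - max_tokens - prompt_tokens, with max_tokens clamped to at least 16 when the sum overflows. A worked sketch with illustrative numbers:

# Token budget shared by the two methods above (all numbers illustrative).
model_context_tokens = 8192   # context window of the model
max_tokens = 1024             # requested completion budget
prompt_tokens = 7500          # measured size of the assembled prompt

rest_tokens = model_context_tokens - max_tokens - prompt_tokens   # -332: over budget
if prompt_tokens + max_tokens > model_context_tokens:
    max_tokens = max(model_context_tokens - prompt_tokens, 16)    # clamped to 692
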
@@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
"""
Chatbot App Config Entity.
"""

pass


class ChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
:param app_model: app model

@@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception('override_config_dict is required when config_from is ARGS')
raise Exception("override_config_dict is required when config_from is ARGS")

config_dict = override_config_dict

@@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)

app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(

@@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)

# opening_statement

@@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):

# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)

# speech_to_text

@@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))
@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,

@@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):

@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,

@@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...

def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,

@@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")

query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")

query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]

extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)

# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")

# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)

# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
override_model_config_dict["retriever_resource"] = {"enabled": True}

# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []

@@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
override_config_dict=override_model_config_dict,
)

# get tracing instance

@@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)

# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)

# init queue manager
queue_manager = MessageBasedAppQueueManager(

@@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)

worker_thread.start()

@@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

return ChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app

@@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
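
All of the generator hunks above share the same worker-thread shape: the request thread spawns a worker, then drains a queue of events until generation finishes. A minimal self-contained sketch of that split, with the app entities and queue manager replaced by stand-ins:

# Stand-in sketch of the generate/worker split used by the generators above.
import queue
import threading

def _generate_worker(q: "queue.Queue[str]") -> None:
    for token in ("Hello", " ", "world"):  # stand-in for streamed model output
        q.put(token)
    q.put(None)  # sentinel: generation finished

q: "queue.Queue[str]" = queue.Queue()
threading.Thread(target=_generate_worker, kwargs={"q": q}).start()
while (item := q.get()) is not None:  # request thread drains the queue
    print(item, end="")
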
@@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner
"""

def run(self, application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity

@@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)

memory = None

@@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)

memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)

@@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)

# moderation

@@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(

@@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return

@@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)

if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)

self.direct_output(

@@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return

@@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)

# get context from datasets

@@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)

dataset_retrieval = DatasetRetrieval(application_generate_entity)

@@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files,
query=query,
context=context,
memory=memory
memory=memory,
)

# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)

if hosting_moderation_result:
return

# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)

db.session.close()

@@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):

# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)
@@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}

return response

@@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)

return response

@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response

@@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, ErrorStreamResponse):

@@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)

@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
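
The converter hunks above turn internal stream events into serialized chunks: pings pass through as-is, everything else is enriched with routing ids and emitted as JSON. A reduced sketch of the same loop, with the event objects simplified to plain dicts (an assumption made only to keep the sketch self-contained):

# Reduced sketch of the stream conversion shown above.
import json
from collections.abc import Generator

def convert_stream(chunks) -> Generator[str, None, None]:
    for chunk in chunks:  # each chunk is assumed to be a dict here
        if chunk.get("event") == "ping":
            yield "ping"
            continue
        response_chunk = {
            "event": chunk["event"],
            "conversation_id": chunk.get("conversation_id"),
            "message_id": chunk.get("message_id"),
            "created_at": chunk.get("created_at"),
        }
        yield json.dumps(response_chunk)
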
@@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
"""
Completion App Config Entity.
"""

pass


class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
:param app_model: app model

@@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)

app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(

@@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)

# text_to_speech

@@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))
@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,

@@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...

def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

@@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")

query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]

extras = {}

@@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None

# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")

# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)

# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []

# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)

# get tracing instance

@@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)

# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)

# init queue manager
queue_manager = MessageBasedAppQueueManager(

@@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)

worker_thread.start()

@@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def _generate_worker(self, flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app

@@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

@@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()

def generate_more_like_this(self, app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate_more_like_this(
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

@@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)

if not message:
raise MessageNotExistsError()

@@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model']
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
override_model_config_dict['model'] = model_dict
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict

# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
message.files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
else:
file_objs = []

# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)

# init application generate entity

@@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={}
extras={},
)

# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)

# init queue manager
queue_manager = MessageBasedAppQueueManager(

@@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)

worker_thread.start()

@@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
    Completion Application Runner
    """

    def run(self, application_generate_entity: CompletionAppGenerateEntity,
            queue_manager: AppQueueManager,
            message: Message) -> None:
    def run(
        self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
    ) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity

@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
            query=query,
        )

        # organize all inputs and template to prompt messages

@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
            query=query,
        )

        # moderation

@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id
                message_id=message.id,
            )
        except ModerationException as e:
            self.direct_output(

@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
                stream=application_generate_entity.stream,
            )
            return

@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
                query=query,
            )

        # get context from datasets

@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from
                application_generate_entity.invoke_from,
            )

            dataset_config = app_config.dataset

@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback,
                message_id=message.id
                message_id=message.id,
            )

        # reorganize all inputs and template to prompt messages

@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
            inputs=inputs,
            files=files,
            query=query,
            context=context
            context=context,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recalc_llm_max_tokens(
            model_config=application_generate_entity.model_conf,
            prompt_messages=prompt_messages
        )
        self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model
            model=application_generate_entity.model_conf.model,
        )

        db.session.close()

@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
        )

@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
        :return:
        """
        response = {
            'event': 'message',
            'task_id': blocking_response.task_id,
            'id': blocking_response.data.id,
            'message_id': blocking_response.data.message_id,
            'mode': blocking_response.data.mode,
            'answer': blocking_response.data.answer,
            'metadata': blocking_response.data.metadata,
            'created_at': blocking_response.data.created_at
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
            "mode": blocking_response.data.mode,
            "answer": blocking_response.data.answer,
            "metadata": blocking_response.data.metadata,
            "created_at": blocking_response.data.created_at,
        }

        return response

@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
        """
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get('metadata', {})
        response['metadata'] = cls._get_simple_metadata(metadata)
        metadata = response.get("metadata", {})
        response["metadata"] = cls._get_simple_metadata(metadata)

        return response

    @classmethod
    def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_full_response(
        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response

@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'message_id': chunk.message_id,
                'created_at': chunk.created_at
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):

@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_simple_response(
        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response

@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'message_id': chunk.message_id,
                'created_at': chunk.created_at
                "event": sub_stream_response.event.value,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.to_dict()
                metadata = sub_stream_response_dict.get('metadata', {})
                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)

@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)

class MessageBasedAppGenerator(BaseAppGenerator):

    def _handle_response(
        self, application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity
        ],
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool = False,
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
        ],
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool = False,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
    ]:
        """
        Handle response.

@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
            conversation=conversation,
            message=message,
            user=user,
            stream=stream
            stream=stream,
        )

        try:

@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
            logger.exception(e)
            raise e

    def _get_conversation_by_user(self, app_model: App, conversation_id: str,
                                  user: Union[Account, EndUser]) -> Conversation:
    def _get_conversation_by_user(
        self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
    ) -> Conversation:
        conversation_filter = [
            Conversation.id == conversation_id,
            Conversation.app_id == app_model.id,
            Conversation.status == 'normal'
            Conversation.status == "normal",
        ]

        if isinstance(user, Account):

@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        if not conversation:
            raise ConversationNotExistsError()

        if conversation.status != 'normal':
        if conversation.status != "normal":
            raise ConversationCompletedError()

        return conversation

    def _get_app_model_config(self, app_model: App,
                              conversation: Optional[Conversation] = None) \
            -> AppModelConfig:
    def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
        if conversation:
            app_model_config = db.session.query(AppModelConfig).filter(
                AppModelConfig.id == conversation.app_model_config_id,
                AppModelConfig.app_id == app_model.id
            ).first()
            app_model_config = (
                db.session.query(AppModelConfig)
                .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
                .first()
            )

            if not app_model_config:
                raise AppModelConfigBrokenError()

@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):

        return app_model_config

    def _init_generate_records(self,
                               application_generate_entity: Union[
                                   ChatAppGenerateEntity,
                                   CompletionAppGenerateEntity,
                                   AgentChatAppGenerateEntity,
                                   AdvancedChatAppGenerateEntity
                               ],
                               conversation: Optional[Conversation] = None) \
            -> tuple[Conversation, Message]:
    def _init_generate_records(
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
        ],
        conversation: Optional[Conversation] = None,
    ) -> tuple[Conversation, Message]:
        """
        Initialize generate records
        :param application_generate_entity: application generate entity

@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        end_user_id = None
        account_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            from_source = 'api'
            from_source = "api"
            end_user_id = application_generate_entity.user_id
        else:
            from_source = 'console'
            from_source = "console"
            account_id = application_generate_entity.user_id

        if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):

@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
            model_provider = application_generate_entity.model_conf.provider
            model_id = application_generate_entity.model_conf.model
            override_model_configs = None
            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
                    and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
                AppMode.AGENT_CHAT,
                AppMode.CHAT,
                AppMode.COMPLETION,
            ]:
                override_model_configs = app_config.app_model_config_dict

        # get conversation introduction

@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                model_id=model_id,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=app_config.app_mode.value,
                name='New conversation',
                name="New conversation",
                inputs=application_generate_entity.inputs,
                introduction=introduction,
                system_instruction="",
                system_instruction_tokens=0,
                status='normal',
                status="normal",
                invoke_from=application_generate_entity.invoke_from.value,
                from_source=from_source,
                from_end_user_id=end_user_id,

@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency='USD',
            currency="USD",
            invoke_from=application_generate_entity.invoke_from.value,
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id
            from_account_id=account_id,
        )

        db.session.add(message)

@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                message_id=message.id,
                type=file.type.value,
                transfer_method=file.transfer_method.value,
                belongs_to='user',
                belongs_to="user",
                url=file.url,
                upload_file_id=file.related_id,
                created_by_role=('account' if account_id else 'end_user'),
                created_by_role=("account" if account_id else "end_user"),
                created_by=account_id or end_user_id,
            )
            db.session.add(message_file)

@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        :param conversation_id: conversation id
        :return: conversation
        """
        conversation = (
            db.session.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

        if not conversation:
            raise ConversationNotExistsError()

@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        :param message_id: message id
        :return: message
        """
        message = (
            db.session.query(Message)
            .filter(Message.id == message_id)
            .first()
        )
        message = db.session.query(Message).filter(Message.id == message_id).first()

        return message

@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (

class MessageBasedAppQueueManager(AppQueueManager):
    def __init__(self, task_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom,
                 conversation_id: str,
                 app_mode: str,
                 message_id: str) -> None:
    def __init__(
        self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
    ) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._conversation_id = str(conversation_id)

@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event
            event=event,
        )

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:

@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event
            event=event,
        )

        self._q.put(message)

        if isinstance(event, QueueStopEvent
                      | QueueErrorEvent
                      | QueueMessageEndEvent
                      | QueueAdvancedChatMessageEndEvent):
        if isinstance(
            event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedException()

@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
    """
    Workflow App Config Entity.
    """

    pass


@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
            app_id=app_model.id,
            app_mode=app_mode,
            workflow_id=workflow.id,
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
                config=features_dict
            ),
            variables=WorkflowVariablesConfigManager.convert(
                workflow=workflow
            ),
            additional_features=cls.convert_features(features_dict, app_mode)
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
            additional_features=cls.convert_features(features_dict, app_mode),
        )

        return app_config

@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
            config=config,
            is_vision=False
            config=config, is_vision=False
        )
        related_config_keys.extend(current_related_config_keys)

@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id,
            config=config,
            only_structure_validate=only_structure_validate
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

@ -34,26 +34,28 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
    @overload
    def generate(
        self, app_model: App,
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: Literal[True] = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Generator[str, None, None]: ...

    @overload
    def generate(
        self, app_model: App,
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: Literal[False] = False,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
        workflow_thread_pool_id: Optional[str] = None,
    ) -> dict: ...

    def generate(

@ -65,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        invoke_from: InvokeFrom,
        stream: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
        workflow_thread_pool_id: Optional[str] = None,
    ):
        """
        Generate App response.

@ -79,26 +81,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param call_depth: call depth
        :param workflow_thread_pool_id: workflow thread pool id
        """
        inputs = args['inputs']
        inputs = args["inputs"]

        # parse files
        files = args['files'] if args.get('files') else []
        files = args["files"] if args.get("files") else []
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        if file_extra_config:
            file_objs = message_file_parser.validate_and_transform_files_arg(
                files,
                file_extra_config,
                user
            )
            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
        else:
            file_objs = []

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow
        )
        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # get tracing instance
        user_id = user.id if isinstance(user, Account) else user.session_id

@ -114,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
            stream=stream,
            invoke_from=invoke_from,
            call_depth=call_depth,
            trace_manager=trace_manager
            trace_manager=trace_manager,
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@ -125,18 +120,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
            application_generate_entity=application_generate_entity,
            invoke_from=invoke_from,
            stream=stream,
            workflow_thread_pool_id=workflow_thread_pool_id
            workflow_thread_pool_id=workflow_thread_pool_id,
        )

    def _generate(
        self, *,
        self,
        *,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        application_generate_entity: WorkflowAppGenerateEntity,
        invoke_from: InvokeFrom,
        stream: bool = True,
        workflow_thread_pool_id: Optional[str] = None
        workflow_thread_pool_id: Optional[str] = None,
    ) -> dict[str, Any] | Generator[str, None, None]:
        """
        Generate App response.

@ -154,17 +150,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            app_mode=app_model.mode
            app_mode=app_model.mode,
        )

        # new thread
        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),  # type: ignore
            'application_generate_entity': application_generate_entity,
            'queue_manager': queue_manager,
            'context': contextvars.copy_context(),
            'workflow_thread_pool_id': workflow_thread_pool_id
        })
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": contextvars.copy_context(),
                "workflow_thread_pool_id": workflow_thread_pool_id,
            },
        )

        worker_thread.start()

@ -177,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
            stream=stream,
        )

        return WorkflowAppGenerateResponseConverter.convert(
            response=response,
            invoke_from=invoke_from
        )
        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(self, app_model: App,
                                  workflow: Workflow,
                                  node_id: str,
                                  user: Account,
                                  args: dict,
                                  stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
    def single_iteration_generate(
        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
    ) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@ -199,16 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param stream: is stream
        """
        if not node_id:
            raise ValueError('node_id is required')
            raise ValueError("node_id is required")

        if args.get('inputs') is None:
            raise ValueError('inputs is required')
        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow
        )
        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(

@ -219,13 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
            user_id=user.id,
            stream=stream,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={
                "auto_generate_conversation_name": False
            },
            extras={"auto_generate_conversation_name": False},
            single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id,
                inputs=args['inputs']
            )
                node_id=node_id, inputs=args["inputs"]
            ),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@ -235,14 +222,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            stream=stream
            stream=stream,
        )

    def _generate_worker(self, flask_app: Flask,
                         application_generate_entity: WorkflowAppGenerateEntity,
                         queue_manager: AppQueueManager,
                         context: contextvars.Context,
                         workflow_thread_pool_id: Optional[str] = None) -> None:
    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app

@ -259,7 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                runner = WorkflowAppRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

                runner.run()

@ -267,14 +257,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError('Incorrect API key provided'),
                    PublishFrom.APPLICATION_MANAGER
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except (ValueError, InvokeError) as e:
                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:

@ -283,14 +272,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
            finally:
                db.session.close()

    def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
                         workflow: Workflow,
                         queue_manager: AppQueueManager,
                         user: Union[Account, EndUser],
                         stream: bool = False) -> Union[
                             WorkflowAppBlockingResponse,
                             Generator[WorkflowAppStreamResponse, None, None]
                         ]:
    def _handle_response(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity

@ -306,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream
            stream=stream,
        )

        try:

@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (

class WorkflowAppQueueManager(AppQueueManager):
    def __init__(self, task_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom,
                 app_mode: str) -> None:
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._app_mode = app_mode

@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
        :param pub_from:
        :return:
        """
        message = WorkflowQueueMessage(
            task_id=self._task_id,
            app_mode=self._app_mode,
            event=event
        )
        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

        self._q.put(message)

        if isinstance(event, QueueStopEvent
                      | QueueErrorEvent
                      | QueueMessageEndEvent
                      | QueueWorkflowSucceededEvent
                      | QueueWorkflowFailedEvent):
        if isinstance(
            event,
            QueueStopEvent
            | QueueErrorEvent
            | QueueMessageEndEvent
            | QueueWorkflowSucceededEvent
            | QueueWorkflowFailedEvent,
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():

@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
    """

    def __init__(
            self,
            application_generate_entity: WorkflowAppGenerateEntity,
            queue_manager: AppQueueManager,
            workflow_thread_pool_id: Optional[str] = None
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity

@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError('App not found')
            raise ValueError("App not found")

        workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError('Workflow not initialized')
            raise ValueError("Workflow not initialized")

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
            workflow_callbacks.append(WorkflowLoggingCallback())

        # if only single iteration run is requested

@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        else:

            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

@ -120,12 +119,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id
            thread_pool_id=self.workflow_thread_pool_id,
        )

        generator = workflow_entry.run(
            callbacks=workflow_callbacks
        )
        generator = workflow_entry.run(callbacks=workflow_callbacks)

        for event in generator:
            self._handle_event(workflow_entry, event)

@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        return cls.convert_blocking_full_response(blocking_response)

    @classmethod
    def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_full_response(
        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response

@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'workflow_run_id': chunk.workflow_run_id,
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):

@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_simple_response(
        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response

@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'workflow_run_id': chunk.workflow_run_id,
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):

@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
    """
    WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

    _workflow: Workflow
    _user: Union[Account, EndUser]
    _task_state: WorkflowTaskState
    _application_generate_entity: WorkflowAppGenerateEntity
    _workflow_system_variables: dict[SystemVariableKey, Any]

    def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
                 workflow: Workflow,
                 queue_manager: AppQueueManager,
                 user: Union[Account, EndUser],
                 stream: bool) -> None:
    def __init__(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        stream: bool,
    ) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity

@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        self._workflow = workflow
        self._workflow_system_variables = {
            SystemVariableKey.FILES: application_generate_entity.files,
            SystemVariableKey.USER_ID: user_id
            SystemVariableKey.USER_ID: user_id,
        }

        self._task_state = WorkflowTaskState()

@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        db.session.refresh(self._user)
        db.session.close()

        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
        if self._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
            -> WorkflowAppBlockingResponse:
    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
        """
        To blocking response.
        :return:

@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                        total_tokens=stream_response.data.total_tokens,
                        total_steps=stream_response.data.total_steps,
                        created_at=int(stream_response.data.created_at),
                        finished_at=int(stream_response.data.finished_at)
                    )
                        finished_at=int(stream_response.data.finished_at),
                    ),
                )

                return response
            else:
                continue

        raise Exception('Queue listening stopped unexpectedly.')
        raise Exception("Queue listening stopped unexpectedly.")

    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
            -> Generator[WorkflowAppStreamResponse, None, None]:
    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[WorkflowAppStreamResponse, None, None]:
        """
        To stream response.
        :return:

@ -158,10 +160,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            if isinstance(stream_response, WorkflowStartStreamResponse):
                workflow_run_id = stream_response.workflow_run_id

            yield WorkflowAppStreamResponse(
                workflow_run_id=workflow_run_id,
                stream_response=stream_response
            )
            yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

    def _listenAudioMsg(self, publisher, task_id: str):
        if not publisher:

@ -171,17 +170,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
            'text_to_speech'].get('autoPlay') == 'enabled':
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
        if (
            features_dict.get("text_to_speech")
            and features_dict["text_to_speech"].get("enabled")
            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
        ):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:

@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                except Exception as e:
                    logger.error(e)
                    break
            yield MessageAudioEndStreamResponse(audio='', task_id=task_id)

            yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.

@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                # init workflow run
                workflow_run = self._handle_workflow_run_start()
                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueNodeStartedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                workflow_node_execution = self._handle_node_execution_start(
                    workflow_run=workflow_run,
                    event=event
                )
                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

                response = self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                    workflow_node_execution=workflow_node_execution,
                )

                if response:

@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                    workflow_node_execution=workflow_node_execution,
                )

                if response:

@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception('Graph runtime state not initialized.')
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
                    outputs=json.dumps(event.outputs)
                    if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
                    else None,
                    conversation_id=None,
                    trace_manager=trace_manager,
                )

@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                self._save_workflow_app_log(workflow_run)

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception('Graph runtime state not initialized.')
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
                    status=WorkflowRunStatus.FAILED
                    if isinstance(event, QueueWorkflowFailedEvent)
                    else WorkflowRunStatus.STOPPED,
                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
                    conversation_id=None,
                    trace_manager=trace_manager,

@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                self._save_workflow_app_log(workflow_run)

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text

@ -394,7 +383,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            if tts_publisher:
                tts_publisher.publish(None)

    def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
        """
        Save workflow app log.

@ -417,7 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        workflow_app_log.workflow_id = workflow_run.workflow_id
        workflow_app_log.workflow_run_id = workflow_run.id
        workflow_app_log.created_from = created_from.value
        workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
        workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
        workflow_app_log.created_by = self._user.id

        db.session.add(workflow_app_log)

@ -431,8 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        :return:
        """
        response = TextChunkStreamResponse(
            task_id=self._application_generate_entity.task_id,
            data=TextChunkStreamResponse.Data(text=text)
            task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
        )

        return response

@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
"""
|
||||
Init graph
|
||||
"""
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
     def _get_graph_and_variable_pool_of_single_iteration(
-            self,
-            workflow: Workflow,
-            node_id: str,
-            user_inputs: dict,
-    ) -> tuple[Graph, VariablePool]:
+        self,
+        workflow: Workflow,
+        node_id: str,
+        user_inputs: dict,
+    ) -> tuple[Graph, VariablePool]:
         """
         Get variable pool of single iteration
         """
         # fetch workflow graph
         graph_config = workflow.graph_dict
         if not graph_config:
-            raise ValueError('workflow graph not found')
+            raise ValueError("workflow graph not found")
 
         graph_config = cast(dict[str, Any], graph_config)
 
-        if 'nodes' not in graph_config or 'edges' not in graph_config:
-            raise ValueError('nodes or edges not found in workflow graph')
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
 
-        if not isinstance(graph_config.get('nodes'), list):
-            raise ValueError('nodes in workflow graph must be a list')
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
 
-        if not isinstance(graph_config.get('edges'), list):
-            raise ValueError('edges in workflow graph must be a list')
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
 
         # filter nodes only in iteration
         node_configs = [
-            node for node in graph_config.get('nodes', [])
-            if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
+            node
+            for node in graph_config.get("nodes", [])
+            if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
         ]
 
-        graph_config['nodes'] = node_configs
+        graph_config["nodes"] = node_configs
 
-        node_ids = [node.get('id') for node in node_configs]
+        node_ids = [node.get("id") for node in node_configs]
 
         # filter edges only in iteration
         edge_configs = [
-            edge for edge in graph_config.get('edges', [])
-            if (edge.get('source') is None or edge.get('source') in node_ids)
-            and (edge.get('target') is None or edge.get('target') in node_ids)
+            edge
+            for edge in graph_config.get("edges", [])
+            if (edge.get("source") is None or edge.get("source") in node_ids)
+            and (edge.get("target") is None or edge.get("target") in node_ids)
         ]
 
-        graph_config['edges'] = edge_configs
+        graph_config["edges"] = edge_configs
 
         # init graph
-        graph = Graph.init(
-            graph_config=graph_config,
-            root_node_id=node_id
-        )
+        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
 
         if not graph:
-            raise ValueError('graph not found in workflow')
+            raise ValueError("graph not found in workflow")
 
         # fetch node config from node id
         iteration_node_config = None
         for node in node_configs:
-            if node.get('id') == node_id:
+            if node.get("id") == node_id:
                 iteration_node_config = node
                 break
 
         if not iteration_node_config:
-            raise ValueError('iteration node id not found in workflow graph')
+            raise ValueError("iteration node id not found in workflow graph")
 
         # Get node class
-        node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
+        node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
         node_cls = node_classes.get(node_type)
         node_cls = cast(type[BaseNode], node_cls)
@@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
 
         try:
             variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
-                graph_config=workflow.graph_dict,
-                config=iteration_node_config
+                graph_config=workflow.graph_dict, config=iteration_node_config
             )
         except NotImplementedError:
             variable_mapping = {}
@@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
             variable_pool=variable_pool,
             tenant_id=workflow.tenant_id,
             node_type=node_type,
-            node_data=IterationNodeData(**iteration_node_config.get('data', {}))
+            node_data=IterationNodeData(**iteration_node_config.get("data", {})),
         )
 
         return graph, variable_pool
@@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
         """
         if isinstance(event, GraphRunStartedEvent):
             self._publish_event(
-                QueueWorkflowStartedEvent(
-                    graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
-                )
+                QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
             )
         elif isinstance(event, GraphRunSucceededEvent):
-            self._publish_event(
-                QueueWorkflowSucceededEvent(outputs=event.outputs)
-            )
+            self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
         elif isinstance(event, GraphRunFailedEvent):
-            self._publish_event(
-                QueueWorkflowFailedEvent(error=event.error)
-            )
+            self._publish_event(QueueWorkflowFailedEvent(error=event.error))
         elif isinstance(event, NodeRunStartedEvent):
             self._publish_event(
                 QueueNodeStartedEvent(
@@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     start_at=event.route_node_state.start_at,
                     node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
@@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
                     inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result else {},
-                    in_iteration_id=event.in_iteration_id
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunFailedEvent):
@@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
                     inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     error=event.route_node_state.node_run_result.error
-                    if event.route_node_state.node_run_result
-                    and event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
                     else "Unknown error",
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunStreamChunkEvent):
@@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
                 QueueTextChunkEvent(
                     text=event.chunk_content,
                     from_variable_selector=event.from_variable_selector,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunRetrieverResourceEvent):
             self._publish_event(
                 QueueRetrieverResourcesEvent(
-                    retriever_resources=event.retriever_resources,
-                    in_iteration_id=event.in_iteration_id
+                    retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
                 )
             )
         elif isinstance(event, ParallelBranchRunStartedEvent):
@@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parallel_start_node_id=event.parallel_start_node_id,
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, ParallelBranchRunSucceededEvent):
@@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parallel_start_node_id=event.parallel_start_node_id,
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, ParallelBranchRunFailedEvent):
@@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     in_iteration_id=event.in_iteration_id,
-                    error=event.error
+                    error=event.error,
                 )
             )
         elif isinstance(event, IterationRunStartedEvent):
@@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     predecessor_node_id=event.predecessor_node_id,
-                    metadata=event.metadata
+                    metadata=event.metadata,
                 )
             )
         elif isinstance(event, IterationRunNextEvent):
@@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, IterationRunFailedEvent) else None
+                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                 )
             )
@@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
 
         # return workflow
         return workflow
 
     def _publish_event(self, event: AppQueueEvent) -> None:
-        self.queue_manager.publish(
-            event,
-            PublishFrom.APPLICATION_MANAGER
-        )
+        self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
 
@@ -30,169 +30,145 @@ _TEXT_COLOR_MAPPING = {
 
 
 class WorkflowLoggingCallback(WorkflowCallback):
-
     def __init__(self) -> None:
         self.current_node_id = None
 
-    def on_event(
-        self,
-        event: GraphEngineEvent
-    ) -> None:
+    def on_event(self, event: GraphEngineEvent) -> None:
         if isinstance(event, GraphRunStartedEvent):
-            self.print_text("\n[GraphRunStartedEvent]", color='pink')
+            self.print_text("\n[GraphRunStartedEvent]", color="pink")
         elif isinstance(event, GraphRunSucceededEvent):
-            self.print_text("\n[GraphRunSucceededEvent]", color='green')
+            self.print_text("\n[GraphRunSucceededEvent]", color="green")
         elif isinstance(event, GraphRunFailedEvent):
-            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
+            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
         elif isinstance(event, NodeRunStartedEvent):
-            self.on_workflow_node_execute_started(
-                event=event
-            )
+            self.on_workflow_node_execute_started(event=event)
         elif isinstance(event, NodeRunSucceededEvent):
-            self.on_workflow_node_execute_succeeded(
-                event=event
-            )
+            self.on_workflow_node_execute_succeeded(event=event)
         elif isinstance(event, NodeRunFailedEvent):
-            self.on_workflow_node_execute_failed(
-                event=event
-            )
+            self.on_workflow_node_execute_failed(event=event)
         elif isinstance(event, NodeRunStreamChunkEvent):
-            self.on_node_text_chunk(
-                event=event
-            )
+            self.on_node_text_chunk(event=event)
         elif isinstance(event, ParallelBranchRunStartedEvent):
-            self.on_workflow_parallel_started(
-                event=event
-            )
+            self.on_workflow_parallel_started(event=event)
         elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
-            self.on_workflow_parallel_completed(
-                event=event
-            )
+            self.on_workflow_parallel_completed(event=event)
         elif isinstance(event, IterationRunStartedEvent):
-            self.on_workflow_iteration_started(
-                event=event
-            )
+            self.on_workflow_iteration_started(event=event)
         elif isinstance(event, IterationRunNextEvent):
-            self.on_workflow_iteration_next(
-                event=event
-            )
+            self.on_workflow_iteration_next(event=event)
         elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
-            self.on_workflow_iteration_completed(
-                event=event
-            )
+            self.on_workflow_iteration_completed(event=event)
         else:
-            self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
+            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
 
-    def on_workflow_node_execute_started(
-        self,
-        event: NodeRunStartedEvent
-    ) -> None:
+    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
         """
         Workflow node execute started
         """
-        self.print_text("\n[NodeRunStartedEvent]", color='yellow')
-        self.print_text(f"Node ID: {event.node_id}", color='yellow')
-        self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
-        self.print_text(f"Type: {event.node_type.value}", color='yellow')
+        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
+        self.print_text(f"Node ID: {event.node_id}", color="yellow")
+        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
+        self.print_text(f"Type: {event.node_type.value}", color="yellow")
 
-    def on_workflow_node_execute_succeeded(
-        self,
-        event: NodeRunSucceededEvent
-    ) -> None:
+    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
         """
         Workflow node execute succeeded
         """
         route_node_state = event.route_node_state
 
-        self.print_text("\n[NodeRunSucceededEvent]", color='green')
-        self.print_text(f"Node ID: {event.node_id}", color='green')
-        self.print_text(f"Node Title: {event.node_data.title}", color='green')
-        self.print_text(f"Type: {event.node_type.value}", color='green')
+        self.print_text("\n[NodeRunSucceededEvent]", color="green")
+        self.print_text(f"Node ID: {event.node_id}", color="green")
+        self.print_text(f"Node Title: {event.node_data.title}", color="green")
+        self.print_text(f"Type: {event.node_type.value}", color="green")
 
         if route_node_state.node_run_result:
             node_run_result = route_node_state.node_run_result
-            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
-                            color='green')
+            self.print_text(
+                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
+            )
             self.print_text(
                 f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
-                color='green')
-            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
-                            color='green')
+                color="green",
+            )
+            self.print_text(
+                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                color="green",
+            )
             self.print_text(
                 f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
-                color='green')
+                color="green",
+            )
 
-    def on_workflow_node_execute_failed(
-        self,
-        event: NodeRunFailedEvent
-    ) -> None:
+    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
         """
         Workflow node execute failed
        """
         route_node_state = event.route_node_state
 
-        self.print_text("\n[NodeRunFailedEvent]", color='red')
-        self.print_text(f"Node ID: {event.node_id}", color='red')
-        self.print_text(f"Node Title: {event.node_data.title}", color='red')
-        self.print_text(f"Type: {event.node_type.value}", color='red')
+        self.print_text("\n[NodeRunFailedEvent]", color="red")
+        self.print_text(f"Node ID: {event.node_id}", color="red")
+        self.print_text(f"Node Title: {event.node_data.title}", color="red")
+        self.print_text(f"Type: {event.node_type.value}", color="red")
 
         if route_node_state.node_run_result:
             node_run_result = route_node_state.node_run_result
-            self.print_text(f"Error: {node_run_result.error}", color='red')
-            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
-                            color='red')
+            self.print_text(f"Error: {node_run_result.error}", color="red")
+            self.print_text(
+                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
+            )
             self.print_text(
                 f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
-                color='red')
-            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
-                            color='red')
+                color="red",
+            )
+            self.print_text(
+                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
+            )
 
-    def on_node_text_chunk(
-        self,
-        event: NodeRunStreamChunkEvent
-    ) -> None:
+    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
         """
         Publish text chunk
         """
         route_node_state = event.route_node_state
         if not self.current_node_id or self.current_node_id != route_node_state.node_id:
             self.current_node_id = route_node_state.node_id
-            self.print_text('\n[NodeRunStreamChunkEvent]')
+            self.print_text("\n[NodeRunStreamChunkEvent]")
             self.print_text(f"Node ID: {route_node_state.node_id}")
 
             node_run_result = route_node_state.node_run_result
             if node_run_result:
                 self.print_text(
-                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
+                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
+                )
 
         self.print_text(event.chunk_content, color="pink", end="")
 
-    def on_workflow_parallel_started(
-        self,
-        event: ParallelBranchRunStartedEvent
-    ) -> None:
+    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
         """
         Publish parallel started
         """
-        self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
-        self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
-        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
+        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
+        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
         if event.in_iteration_id:
-            self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
 
     def on_workflow_parallel_completed(
-        self,
-        event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
+        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
     ) -> None:
         """
         Publish parallel completed
         """
         if isinstance(event, ParallelBranchRunSucceededEvent):
-            color = 'blue'
+            color = "blue"
         elif isinstance(event, ParallelBranchRunFailedEvent):
-            color = 'red'
+            color = "red"
 
-        self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
+        self.print_text(
+            "\n[ParallelBranchRunSucceededEvent]"
+            if isinstance(event, ParallelBranchRunSucceededEvent)
+            else "\n[ParallelBranchRunFailedEvent]",
+            color=color,
+        )
         self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
         self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
         if event.in_iteration_id:
@@ -201,43 +177,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
         if isinstance(event, ParallelBranchRunFailedEvent):
             self.print_text(f"Error: {event.error}", color=color)
 
-    def on_workflow_iteration_started(
-        self,
-        event: IterationRunStartedEvent
-    ) -> None:
+    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
         """
         Publish iteration started
         """
-        self.print_text("\n[IterationRunStartedEvent]", color='blue')
-        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
+        self.print_text("\n[IterationRunStartedEvent]", color="blue")
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
 
-    def on_workflow_iteration_next(
-        self,
-        event: IterationRunNextEvent
-    ) -> None:
+    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
         """
         Publish iteration next
         """
-        self.print_text("\n[IterationRunNextEvent]", color='blue')
-        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
-        self.print_text(f"Iteration Index: {event.index}", color='blue')
+        self.print_text("\n[IterationRunNextEvent]", color="blue")
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
+        self.print_text(f"Iteration Index: {event.index}", color="blue")
 
-    def on_workflow_iteration_completed(
-        self,
-        event: IterationRunSucceededEvent | IterationRunFailedEvent
-    ) -> None:
+    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
         """
         Publish iteration completed
         """
-        self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
-        self.print_text(f"Node ID: {event.iteration_id}", color='blue')
+        self.print_text(
+            "\n[IterationRunSucceededEvent]"
+            if isinstance(event, IterationRunSucceededEvent)
+            else "\n[IterationRunFailedEvent]",
+            color="blue",
+        )
+        self.print_text(f"Node ID: {event.iteration_id}", color="blue")
 
-    def print_text(
-        self, text: str, color: Optional[str] = None, end: str = "\n"
-    ) -> None:
+    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
         """Print text with highlighting and no end characters."""
         text_to_print = self._get_colored_text(text, color) if color else text
-        print(f'{text_to_print}', end=end)
+        print(f"{text_to_print}", end=end)
 
     def _get_colored_text(self, text: str, color: str) -> str:
         """Get colored text."""
 
@@ -15,13 +15,14 @@ class InvokeFrom(Enum):
     """
     Invoke From.
     """
-    SERVICE_API = 'service-api'
-    WEB_APP = 'web-app'
-    EXPLORE = 'explore'
-    DEBUGGER = 'debugger'
+
+    SERVICE_API = "service-api"
+    WEB_APP = "web-app"
+    EXPLORE = "explore"
+    DEBUGGER = "debugger"
 
     @classmethod
-    def value_of(cls, value: str) -> 'InvokeFrom':
+    def value_of(cls, value: str) -> "InvokeFrom":
         """
         Get value of given mode.
 
@@ -31,7 +32,7 @@ class InvokeFrom(Enum):
         for mode in cls:
             if mode.value == value:
                 return mode
-        raise ValueError(f'invalid invoke from value {value}')
+        raise ValueError(f"invalid invoke from value {value}")
 
     def to_source(self) -> str:
         """
@@ -40,21 +41,22 @@ class InvokeFrom(Enum):
         :return: source
         """
         if self == InvokeFrom.WEB_APP:
-            return 'web_app'
+            return "web_app"
         elif self == InvokeFrom.DEBUGGER:
-            return 'dev'
+            return "dev"
         elif self == InvokeFrom.EXPLORE:
-            return 'explore_app'
+            return "explore_app"
         elif self == InvokeFrom.SERVICE_API:
-            return 'api'
+            return "api"
 
-        return 'dev'
+        return "dev"
 
 
 class ModelConfigWithCredentialsEntity(BaseModel):
     """
     Model Config With Credentials Entity.
     """
+
     provider: str
     model: str
     model_schema: AIModelEntity
@@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
     """
     App Generate Entity.
     """
+
     task_id: str
 
     # app config
@@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """
     Chat Application Generate Entity.
     """
+
     # app config
     app_config: EasyUIBasedAppConfig
     model_conf: ModelConfigWithCredentialsEntity
@@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Chat Application Generate Entity.
     """
+
     conversation_id: Optional[str] = None
 
 
@@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Completion Application Generate Entity.
     """
+
     pass
 
 
@@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Agent Chat Application Generate Entity.
     """
+
     conversation_id: Optional[str] = None
 
 
@@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
     """
     Advanced Chat Application Generate Entity.
     """
+
     # app config
     app_config: WorkflowUIBasedAppConfig
 
@@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
        """
        Single Iteration Run Entity.
        """
+
        node_id: str
        inputs: dict
 
    single_iteration_run: Optional[SingleIterationRunEntity] = None
 
 
 class WorkflowAppGenerateEntity(AppGenerateEntity):
    """
    Workflow Application Generate Entity.
    """
+
    # app config
    app_config: WorkflowUIBasedAppConfig
 
@@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
        """
        Single Iteration Run Entity.
        """
+
        node_id: str
        inputs: dict
 
@@ -14,6 +14,7 @@ class QueueEvent(str, Enum):
     """
     QueueEvent enum
     """
+
     LLM_CHUNK = "llm_chunk"
     TEXT_CHUNK = "text_chunk"
     AGENT_MESSAGE = "agent_message"
@@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel):
     """
     QueueEvent abstract entity
     """
+
     event: QueueEvent
 
 
@@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent):
     QueueLLMChunkEvent entity
     Only for basic mode apps
     """
+
     event: QueueEvent = QueueEvent.LLM_CHUNK
     chunk: LLMResultChunk
 
 
 class QueueIterationStartEvent(AppQueueEvent):
     """
     QueueIterationStartEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_START
     node_execution_id: str
     node_id: str
@@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent):
     predecessor_node_id: Optional[str] = None
     metadata: Optional[dict[str, Any]] = None
 
 
 class QueueIterationNextEvent(AppQueueEvent):
     """
     QueueIterationNextEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_NEXT
 
     index: int
@@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent):
     """parent parallel start node id if node is in parallel"""
 
     node_run_index: int
-    output: Optional[Any] = None # output for the current iteration
+    output: Optional[Any] = None  # output for the current iteration
 
-    @field_validator('output', mode='before')
+    @field_validator("output", mode="before")
     @classmethod
     def set_output(cls, v):
         """
@@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent):
             return None
         if isinstance(v, int | float | str | bool | dict | list):
             return v
-        raise ValueError('output must be a valid type')
+        raise ValueError("output must be a valid type")
 
 
 class QueueIterationCompletedEvent(AppQueueEvent):
     """
     QueueIterationCompletedEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_COMPLETED
 
     node_execution_id: str
@@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     parent_parallel_start_node_id: Optional[str] = None
     """parent parallel start node id if node is in parallel"""
     start_at: datetime
 
     node_run_index: int
     inputs: Optional[dict[str, Any]] = None
     outputs: Optional[dict[str, Any]] = None
@@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     QueueTextChunkEvent entity
     """
+
     event: QueueEvent = QueueEvent.TEXT_CHUNK
     text: str
     from_variable_selector: Optional[list[str]] = None
@@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent):
     """
     QueueMessageEvent entity
     """
+
     event: QueueEvent = QueueEvent.AGENT_MESSAGE
     chunk: LLMResultChunk
 
 
 class QueueMessageReplaceEvent(AppQueueEvent):
     """
     QueueMessageReplaceEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_REPLACE
     text: str
 
@@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
     """
     QueueRetrieverResourcesEvent entity
     """
+
     event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
     retriever_resources: list[dict]
     in_iteration_id: Optional[str] = None
@@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
     """
     QueueAnnotationReplyEvent entity
     """
+
     event: QueueEvent = QueueEvent.ANNOTATION_REPLY
     message_annotation_id: str
 
@@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
     """
     QueueMessageEndEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_END
     llm_result: Optional[LLMResult] = None
 
@@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
     """
     QueueAdvancedChatMessageEndEvent entity
     """
+
     event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
 
 
@@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
     """
     QueueWorkflowStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_STARTED
     graph_runtime_state: GraphRuntimeState
 
@@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
     """
     QueueWorkflowSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
     outputs: Optional[dict[str, Any]] = None
 
@@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
     """
     QueueWorkflowFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_FAILED
     error: str
 
@@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
     """
     QueueNodeStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_STARTED
 
     node_execution_id: str
@@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """
     QueueNodeSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_SUCCEEDED
 
     node_execution_id: str
@@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """
     QueueNodeFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_FAILED
 
     node_execution_id: str
@@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
     """
     QueueAgentThoughtEvent entity
     """
+
     event: QueueEvent = QueueEvent.AGENT_THOUGHT
     agent_thought_id: str
 
@@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
     """
     QueueAgentThoughtEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_FILE
     message_file_id: str
 
@@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
     """
     QueueErrorEvent entity
     """
+
     event: QueueEvent = QueueEvent.ERROR
     error: Any = None
 
@@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
     """
     QueuePingEvent entity
     """
+
     event: QueueEvent = QueueEvent.PING
 
 
@@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
     """
     QueueStopEvent entity
     """
+
     class StopBy(Enum):
         """
         Stop by enum
         """
+
         USER_MANUAL = "user-manual"
         ANNOTATION_REPLY = "annotation-reply"
         OUTPUT_MODERATION = "output-moderation"
@@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent):
         To stop reason
         """
         reason_mapping = {
-            QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
-            QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
-            QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
-            QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
+            QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
+            QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
+            QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
+            QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
         }
 
-        return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
+        return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
 
 
 class QueueMessage(BaseModel):
     """
     QueueMessage abstract entity
     """
+
     task_id: str
     app_mode: str
     event: AppQueueEvent
@@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
     """
     MessageQueueMessage entity
     """
+
     message_id: str
     conversation_id: str
 
@@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage):
     """
     WorkflowQueueMessage entity
     """
+
     pass
 
 
@@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
     """
     QueueParallelBranchRunStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
 
     parallel_id: str
@@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
     """
     QueueParallelBranchRunSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
 
     parallel_id: str
@@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
     """
     QueueParallelBranchRunFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
 
     parallel_id: str
 
@@ -12,6 +12,7 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+
     metadata: dict = {}
 
 
@@ -19,6 +20,7 @@ class EasyUITaskState(TaskState):
     """
     EasyUITaskState entity
     """
+
     llm_result: LLMResult
 
 
@@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState):
     """
     WorkflowTaskState entity
     """
+
     answer: str = ""
 
 
@@ -33,6 +36,7 @@ class StreamEvent(Enum):
     """
     Stream event
     """
+
     PING = "ping"
     ERROR = "error"
     MESSAGE = "message"
@@ -60,6 +64,7 @@ class StreamResponse(BaseModel):
     """
     StreamResponse entity
     """
+
     event: StreamEvent
     task_id: str
 
@@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
     """
     ErrorStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.ERROR
     err: Exception
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
     """
     MessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE
     id: str
     answer: str
@@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
     """
     MessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.TTS_MESSAGE
     audio: str
 
@@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
     """
     MessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.TTS_MESSAGE_END
     audio: str
 
@@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
     """
     MessageEndStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_END
     id: str
     metadata: dict = {}
@@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
     """
     MessageFileStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_FILE
     id: str
     type: str
@@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
     """
     MessageReplaceStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_REPLACE
     answer: str
 
@@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
     """
     AgentThoughtStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.AGENT_THOUGHT
     id: str
     position: int
@@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
     """
     AgentMessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.AGENT_MESSAGE
     id: str
     answer: str
@@ -162,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        workflow_id: str
        sequence_number: int
@@ -182,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        workflow_id: str
        sequence_number: int
@@ -210,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        node_id: str
        node_type: str
@@ -249,7 +266,7 @@ class NodeStartStreamResponse(StreamResponse):
                 "parent_parallel_id": self.data.parent_parallel_id,
                 "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                 "iteration_id": self.data.iteration_id,
-            }
+            },
         }
 
 
@@ -262,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        node_id: str
        node_type: str
@@ -315,9 +333,9 @@ class NodeFinishStreamResponse(StreamResponse):
                 "parent_parallel_id": self.data.parent_parallel_id,
                 "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                 "iteration_id": self.data.iteration_id,
-            }
+            },
         }
 
 
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
@@ -328,6 +346,7 @@ class ParallelBranchStartStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
@@ -349,6 +368,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
@@ -372,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        node_id: str
        node_type: str
@@ -397,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        node_id: str
        node_type: str
@@ -422,6 +444,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        id: str
        node_id: str
        node_type: str
@@ -454,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        text: str
 
    event: StreamEvent = StreamEvent.TEXT_CHUNK
@@ -469,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
        """
        Data entity
        """
+
        text: str
 
    event: StreamEvent = StreamEvent.TEXT_REPLACE
@@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse):
     """
     PingStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.PING
 
 
@@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel):
     """
     AppStreamResponse entity
     """
+
     stream_response: StreamResponse
 
 
@@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
     """
     ChatbotAppStreamResponse entity
     """
+
     conversation_id: str
     message_id: str
     created_at: int
@@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
     """
     CompletionAppStreamResponse entity
     """
+
     message_id: str
     created_at: int
 
@@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
     """
     WorkflowAppStreamResponse entity
     """
+
     workflow_run_id: Optional[str] = None
 
 
@@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel):
     """
     AppBlockingResponse entity
     """
+
     task_id: str
 
     def to_dict(self) -> dict:
@@ -532,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """
+
        id: str
        mode: str
        conversation_id: str
@@ -552,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """
+
        id: str
        mode: str
        message_id: str
@@ -571,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """
+
        id: str
        workflow_id: str
        status: str
 
@@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)
 
 
 class AnnotationReplyFeature:
-    def query(self, app_record: App,
-              message: Message,
-              query: str,
-              user_id: str,
-              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+    def query(
+        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
+    ) -> Optional[MessageAnnotation]:
         """
         Query app annotations to reply
         :param app_record: app record
@@ -27,8 +25,9 @@ class AnnotationReplyFeature:
         :param invoke_from: invoke from
         :return:
         """
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_record.id).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
+        )
 
         if not annotation_setting:
             return None
@@ -41,55 +40,50 @@ class AnnotationReplyFeature:
             embedding_model_name = collection_binding_detail.model_name
 
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                embedding_provider_name,
-                embedding_model_name,
-                'annotation'
+                embedding_provider_name, embedding_model_name, "annotation"
             )
 
             dataset = Dataset(
                 id=app_record.id,
                 tenant_id=app_record.tenant_id,
-                indexing_technique='high_quality',
+                indexing_technique="high_quality",
                 embedding_model_provider=embedding_provider_name,
                 embedding_model=embedding_model_name,
-                collection_binding_id=dataset_collection_binding.id
+                collection_binding_id=dataset_collection_binding.id,
             )
 
-            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
 
             documents = vector.search_by_vector(
-                query=query,
-                top_k=1,
-                score_threshold=score_threshold,
-                filter={
-                    'group_id': [dataset.id]
-                }
+                query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
             )
 
             if documents:
-                annotation_id = documents[0].metadata['annotation_id']
-                score = documents[0].metadata['score']
+                annotation_id = documents[0].metadata["annotation_id"]
+                score = documents[0].metadata["score"]
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                 if annotation:
                     if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
-                        from_source = 'api'
+                        from_source = "api"
                     else:
-                        from_source = 'console'
+                        from_source = "console"
 
                     # insert annotation history
-                    AppAnnotationService.add_annotation_history(annotation.id,
-                                                                app_record.id,
-                                                                annotation.question,
-                                                                annotation.content,
-                                                                query,
-                                                                user_id,
-                                                                message.id,
-                                                                from_source,
-                                                                score)
+                    AppAnnotationService.add_annotation_history(
+                        annotation.id,
+                        app_record.id,
+                        annotation.question,
+                        annotation.content,
+                        query,
+                        user_id,
+                        message.id,
+                        from_source,
+                        score,
+                    )
 
                     return annotation
         except Exception as e:
-            logger.warning(f'Query annotation failed, exception: {str(e)}.')
+            logger.warning(f"Query annotation failed, exception: {str(e)}.")
             return None
 
         return None
 
@@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)
 
 
 class HostingModerationFeature:
-    def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
-              prompt_messages: list[PromptMessage]) -> bool:
+    def check(
+        self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
+    ) -> bool:
         """
         Check hosting moderation
         :param application_generate_entity: application generate entity
@@ -23,9 +24,6 @@ class HostingModerationFeature:
             if isinstance(prompt_message.content, str):
                 text += prompt_message.content + "\n"
 
-        moderation_result = moderation.check_moderation(
-            model_config,
-            text
-        )
+        moderation_result = moderation.check_moderation(model_config, text)
 
         return moderation_result
 
@@ -19,7 +19,7 @@ class RateLimit:
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
     _instance_dict = {}
 
-    def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
+    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:
             instance = super().__new__(cls)
             cls._instance_dict[client_id] = instance
@@ -27,13 +27,13 @@ class RateLimit:
 
     def __init__(self, client_id: str, max_active_requests: int):
         self.max_active_requests = max_active_requests
-        if hasattr(self, 'initialized'):
+        if hasattr(self, "initialized"):
             return
         self.initialized = True
         self.client_id = client_id
         self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
         self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
-        self.last_recalculate_time = float('-inf')
+        self.last_recalculate_time = float("-inf")
         self.flush_cache(use_local_value=True)
 
     def flush_cache(self, use_local_value=False):
@@ -46,7 +46,7 @@ class RateLimit:
                 pipe.execute()
         else:
             with redis_client.pipeline() as pipe:
-                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
+                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
                 redis_client.expire(self.max_active_requests_key, timedelta(days=1))
 
         # flush max active requests (in-transit request list)
@@ -54,8 +54,11 @@ class RateLimit:
             return
         request_details = redis_client.hgetall(self.active_requests_key)
         redis_client.expire(self.active_requests_key, timedelta(days=1))
-        timeout_requests = [k for k, v in request_details.items() if
-                            time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
+        timeout_requests = [
+            k
+            for k, v in request_details.items()
+            if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
+        ]
         if timeout_requests:
             redis_client.hdel(self.active_requests_key, *timeout_requests)
 
@@ -69,8 +72,10 @@ class RateLimit:
 
         active_requests_count = redis_client.hlen(self.active_requests_key)
         if active_requests_count >= self.max_active_requests:
-            raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
-                                              "concurrent requests allowed is {}.".format(self.max_active_requests))
+            raise AppInvokeQuotaExceededError(
+                "Too many requests. Please try again later. The current maximum "
+                "concurrent requests allowed is {}.".format(self.max_active_requests)
+            )
         redis_client.hset(self.active_requests_key, request_id, str(time.time()))
         return request_id
 
@@ -116,5 +121,5 @@ class RateLimitGenerator:
         if not self.closed:
             self.closed = True
             self.rate_limit.exit(self.request_id)
-            if self.generator is not None and hasattr(self.generator, 'close'):
+            if self.generator is not None and hasattr(self.generator, "close"):
                 self.generator.close()
 
@@ -25,25 +25,25 @@ from .variables import (
 )
 
 __all__ = [
-    'IntegerVariable',
-    'FloatVariable',
-    'ObjectVariable',
-    'SecretVariable',
-    'StringVariable',
-    'ArrayAnyVariable',
-    'Variable',
-    'SegmentType',
-    'SegmentGroup',
-    'Segment',
-    'NoneSegment',
-    'NoneVariable',
-    'IntegerSegment',
-    'FloatSegment',
-    'ObjectSegment',
-    'ArrayAnySegment',
-    'StringSegment',
-    'ArrayStringVariable',
-    'ArrayNumberVariable',
-    'ArrayObjectVariable',
-    'ArraySegment',
+    "IntegerVariable",
+    "FloatVariable",
+    "ObjectVariable",
+    "SecretVariable",
+    "StringVariable",
+    "ArrayAnyVariable",
+    "Variable",
+    "SegmentType",
+    "SegmentGroup",
+    "Segment",
+    "NoneSegment",
+    "NoneVariable",
+    "IntegerSegment",
+    "FloatSegment",
+    "ObjectSegment",
+    "ArrayAnySegment",
+    "StringSegment",
+    "ArrayStringVariable",
+    "ArrayNumberVariable",
+    "ArrayObjectVariable",
+    "ArraySegment",
 ]
 
@@ -28,12 +28,12 @@ from .variables import (
 
 
 def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
-    if (value_type := mapping.get('value_type')) is None:
-        raise VariableError('missing value type')
-    if not mapping.get('name'):
-        raise VariableError('missing name')
-    if (value := mapping.get('value')) is None:
-        raise VariableError('missing value')
+    if (value_type := mapping.get("value_type")) is None:
+        raise VariableError("missing value type")
+    if not mapping.get("name"):
+        raise VariableError("missing name")
+    if (value := mapping.get("value")) is None:
+        raise VariableError("missing value")
     match value_type:
         case SegmentType.STRING:
             result = StringVariable.model_validate(mapping)
@@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.NUMBER if isinstance(value, float):
             result = FloatVariable.model_validate(mapping)
         case SegmentType.NUMBER if not isinstance(value, float | int):
-            raise VariableError(f'invalid number value {value}')
+            raise VariableError(f"invalid number value {value}")
         case SegmentType.OBJECT if isinstance(value, dict):
             result = ObjectVariable.model_validate(mapping)
         case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.ARRAY_OBJECT if isinstance(value, list):
             result = ArrayObjectVariable.model_validate(mapping)
         case _:
-            raise VariableError(f'not supported value type {value_type}')
+            raise VariableError(f"not supported value type {value_type}")
     if result.size > dify_config.MAX_VARIABLE_SIZE:
-        raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
+        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
     return result
 
 
@@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, list):
         return ArrayAnySegment(value=value)
-    raise ValueError(f'not supported value {value}')
+    raise ValueError(f"not supported value {value}")
 
@@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
 
 from . import SegmentGroup, factory
 
-VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
+VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
 
 
 def convert_template(*, template: str, variable_pool: VariablePool):
     parts = re.split(VARIABLE_PATTERN, template)
     segments = []
     for part in filter(lambda x: x, parts):
-        if '.' in part and (value := variable_pool.get(part.split('.'))):
+        if "." in part and (value := variable_pool.get(part.split("."))):
             segments.append(value)
         else:
             segments.append(factory.build_segment(part))
 
@@ -8,15 +8,15 @@ class SegmentGroup(Segment):
 
     @property
     def text(self):
-        return ''.join([segment.text for segment in self.value])
+        return "".join([segment.text for segment in self.value])
 
     @property
     def log(self):
-        return ''.join([segment.log for segment in self.value])
+        return "".join([segment.log for segment in self.value])
 
     @property
     def markdown(self):
-        return ''.join([segment.markdown for segment in self.value])
+        return "".join([segment.markdown for segment in self.value])
 
     def to_object(self):
         return [segment.to_object() for segment in self.value]
 
@ -14,13 +14,13 @@ class Segment(BaseModel):
|
|||
value_type: SegmentType
|
||||
value: Any
|
||||
|
||||
@field_validator('value_type')
|
||||
@field_validator("value_type")
|
||||
def validate_value_type(cls, value):
|
||||
"""
|
||||
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
||||
If the value is different, a ValueError is raised.
|
||||
"""
|
||||
if value != cls.model_fields['value_type'].default:
|
||||
if value != cls.model_fields["value_type"].default:
|
||||
raise ValueError("Cannot modify 'value_type'")
|
||||
return value
|
||||
|
||||
|
@ -50,15 +50,15 @@ class NoneSegment(Segment):
|
|||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return 'null'
|
||||
return "null"
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return 'null'
|
||||
return "null"
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return 'null'
|
||||
return "null"
|
||||
|
||||
|
||||
class StringSegment(Segment):
|
||||
|
@ -76,24 +76,21 @@ class IntegerSegment(Segment):
|
|||
value: int
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.OBJECT
|
||||
value: Mapping[str, Any]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
||||
return json.dumps(self.model_dump()["value"], ensure_ascii=False)
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
class ArraySegment(Segment):
|
||||
|
@ -101,11 +98,11 @@ class ArraySegment(Segment):
|
|||
def markdown(self) -> str:
|
||||
items = []
|
||||
for item in self.value:
|
||||
if hasattr(item, 'to_markdown'):
|
||||
if hasattr(item, "to_markdown"):
|
||||
items.append(item.to_markdown())
|
||||
else:
|
||||
items.append(str(item))
|
||||
return '\n'.join(items)
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
|
@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
|
|||
class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[Mapping[str, Any]]
@@ -2,14 +2,14 @@ from enum import Enum
class SegmentType(str, Enum):
    NONE = 'none'
    NUMBER = 'number'
    STRING = 'string'
    SECRET = 'secret'
    ARRAY_ANY = 'array[any]'
    ARRAY_STRING = 'array[string]'
    ARRAY_NUMBER = 'array[number]'
    ARRAY_OBJECT = 'array[object]'
    OBJECT = 'object'
    NONE = "none"
    NUMBER = "number"
    STRING = "string"
    SECRET = "secret"
    ARRAY_ANY = "array[any]"
    ARRAY_STRING = "array[string]"
    ARRAY_NUMBER = "array[number]"
    ARRAY_OBJECT = "array[object]"
    OBJECT = "object"

    GROUP = 'group'
    GROUP = "group"

@@ -23,11 +23,11 @@ class Variable(Segment):
    """

    id: str = Field(
        default='',
        default="",
        description="Unique identity for variable. It's only used by environment variables now.",
    )
    name: str
    description: str = Field(default='', description='Description of the variable.')
    description: str = Field(default="", description="Description of the variable.")


class StringVariable(StringSegment, Variable):

@@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
    pass


class SecretVariable(StringVariable):
    value_type: SegmentType = SegmentType.SECRET
@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
    _task_state: TaskState
    _application_generate_entity: AppGenerateEntity

    def __init__(self, application_generate_entity: AppGenerateEntity,
                 queue_manager: AppQueueManager,
                 user: Union[Account, EndUser],
                 stream: bool) -> None:
    def __init__(
        self,
        application_generate_entity: AppGenerateEntity,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        stream: bool,
    ) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity

@@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            err = InvokeAuthorizationError('Incorrect API key provided')
            err = InvokeAuthorizationError("Incorrect API key provided")
        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
            err = e
        else:
            err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
            err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

        if message:
            refetch_message = db.session.query(Message).filter(Message.id == message.id).first()

            if refetch_message:
                err_desc = self._error_to_desc(err)
                refetch_message.status = 'error'
                refetch_message.status = "error"
                refetch_message.error = err_desc

                db.session.commit()

@@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
        :return:
        """
        if isinstance(e, QuotaExceededError):
            return ("Your quota for Dify Hosted Model Provider has been exhausted. "
                    "Please go to Settings -> Model Provider to complete your own provider credentials.")
            return (
                "Your quota for Dify Hosted Model Provider has been exhausted. "
                "Please go to Settings -> Model Provider to complete your own provider credentials."
            )

        message = getattr(e, 'description', str(e))
        message = getattr(e, "description", str(e))
        if not message:
            message = 'Internal Server Error, please contact support.'
            message = "Internal Server Error, please contact support."

        return message

@@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
        :param e: exception
        :return:
        """
        return ErrorStreamResponse(
            task_id=self._application_generate_entity.task_id,
            err=e
        )
        return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

    def _ping_stream_response(self) -> PingStreamResponse:
        """

@@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
        return OutputModeration(
            tenant_id=app_config.tenant_id,
            app_id=app_config.app_id,
            rule=ModerationRule(
                type=sensitive_word_avoidance.type,
                config=sensitive_word_avoidance.config
            ),
            queue_manager=self._queue_manager
            rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
            queue_manager=self._queue_manager,
        )

    def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:

@@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
            self._output_moderation_handler.stop_thread()

            completion = self._output_moderation_handler.moderation_completion(
                completion=completion,
                public_event=False
                completion=completion, public_event=False
            )

            self._output_moderation_handler = None
@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
    """
    EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """
    _task_state: EasyUITaskState
    _application_generate_entity: Union[
        ChatAppGenerateEntity,
        CompletionAppGenerateEntity,
        AgentChatAppGenerateEntity
    ]

    def __init__(self, application_generate_entity: Union[
        ChatAppGenerateEntity,
        CompletionAppGenerateEntity,
        AgentChatAppGenerateEntity
    ],
                 queue_manager: AppQueueManager,
                 conversation: Conversation,
                 message: Message,
                 user: Union[Account, EndUser],
                 stream: bool) -> None:
    _task_state: EasyUITaskState
    _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]

    def __init__(
        self,
        application_generate_entity: Union[
            ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
        ],
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool,
    ) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity

@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                model=self._model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
                usage=LLMUsage.empty_usage(),
            )
        )

        self._conversation_name_generate_thread = None

    def process(
        self,
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
    ]:
        """
        Process generate task pipeline.

@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
            # start generate conversation name thread
            self._conversation_name_generate_thread = self._generate_conversation_name(
                self._conversation,
                self._application_generate_entity.query
                self._conversation, self._application_generate_entity.query
            )

        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
        if self._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse
    ]:
    def _to_blocking_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
        """
        Process blocking response.
        :return:

@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {
                    'usage': jsonable_encoder(self._task_state.llm_result.usage)
                }
                extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
                if self._task_state.metadata:
                    extras['metadata'] = self._task_state.metadata
                    extras["metadata"] = self._task_state.metadata

                if self._conversation.mode == AppMode.COMPLETION.value:
                    response = CompletionAppBlockingResponse(

@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                            message_id=self._message.id,
                            answer=self._task_state.llm_result.message.content,
                            created_at=int(self._message.created_at.timestamp()),
                            **extras
                        )
                            **extras,
                        ),
                    )
                else:
                    response = ChatbotAppBlockingResponse(

@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                            message_id=self._message.id,
                            answer=self._task_state.llm_result.message.content,
                            created_at=int(self._message.created_at.timestamp()),
                            **extras
                        )
                            **extras,
                        ),
                    )

                return response
            else:
                continue

        raise Exception('Queue listening stopped unexpectedly.')
        raise Exception("Queue listening stopped unexpectedly.")

    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
            -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
        """
        To stream response.
        :return:

@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                yield CompletionAppStreamResponse(
                    message_id=self._message.id,
                    created_at=int(self._message.created_at.timestamp()),
                    stream_response=stream_response
                    stream_response=stream_response,
                )
            else:
                yield ChatbotAppStreamResponse(
                    conversation_id=self._conversation.id,
                    message_id=self._message.id,
                    created_at=int(self._message.created_at.timestamp()),
                    stream_response=stream_response
                    stream_response=stream_response,
                )

    def _listenAudioMsg(self, publisher, task_id: str):

@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        tenant_id = self._application_generate_entity.app_config.tenant_id
        task_id = self._application_generate_entity.task_id
        publisher = None
        text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
        if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
        text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
        if (
            text_to_speech_dict
            and text_to_speech_dict.get("autoPlay") == "enabled"
            and text_to_speech_dict.get("enabled")
        ):
            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id)

@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                break
            else:
                start_listener_time = time.time()
                yield MessageAudioStreamResponse(audio=audio.audio,
                                                 task_id=task_id)
        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
                yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self,
        publisher: AppGeneratorTTSPublisher,
        trace_manager: Optional[TraceQueueManager] = None
        self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.

@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> None:
    def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
        """
        Save message.
        :return:

@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

        self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            self._model_config.mode,
            self._task_state.llm_result.prompt_messages
            self._model_config.mode, self._task_state.llm_result.prompt_messages
        )
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer = (
            PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
            if llm_result.message.content
            else ""
        )
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price
        self._message.currency = usage.currency
        self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
            if self._task_state.metadata else None
        self._message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )

        db.session.commit()

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.MESSAGE_TRACE,
                    conversation_id=self._conversation.id,
                    message_id=self._message.id
                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
                )
            )

@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.app_config.app_mode in [
                AppMode.AGENT_CHAT,
                AppMode.CHAT
            ] and self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
            and self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras,
        )

    def _handle_stop(self, event: QueueStopEvent) -> None:

@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        model = model_config.model

        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle,
            model=model_config.model
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        # calculate num tokens
        prompt_tokens = 0
        if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
            prompt_tokens = model_instance.get_llm_num_tokens(
                self._task_state.llm_result.prompt_messages
            )
            prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)

        completion_tokens = 0
        if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
            completion_tokens = model_instance.get_llm_num_tokens(
                [self._task_state.llm_result.message]
            )
            completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])

        credentials = model_config.credentials

@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                model,
                credentials,
                prompt_tokens,
                completion_tokens
                model, credentials, prompt_tokens, completion_tokens
            )

    def _message_end_to_stream_response(self) -> MessageEndStreamResponse:

@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        Message end to stream response.
        :return:
        """
        self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
        self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

        extras = {}
        if self._task_state.metadata:
            extras['metadata'] = self._task_state.metadata
            extras["metadata"] = self._task_state.metadata

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=self._message.id,
            **extras
            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
        )

    def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        :return:
        """
        return AgentMessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer
            task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
        )

    def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:

@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        :return:
        """
        agent_thought: MessageAgentThought = (
            db.session.query(MessageAgentThought)
            .filter(MessageAgentThought.id == event.agent_thought_id)
            .first()
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
        )
        db.session.refresh(agent_thought)
        db.session.close()

@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                tool=agent_thought.tool,
                tool_labels=agent_thought.tool_labels,
                tool_input=agent_thought.tool_input,
                message_files=agent_thought.files
                message_files=agent_thought.files,
            )

        return None

@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                        prompt_messages=self._task_state.llm_result.prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                        )
                            message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
                        ),
                    )
                ), PublishFrom.TASK_PIPELINE
                ),
                PublishFrom.TASK_PIPELINE,
            )

            self._queue_manager.publish(
                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                PublishFrom.TASK_PIPELINE
                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
            )
            return True
        else:
@@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
    _application_generate_entity: Union[
        ChatAppGenerateEntity,
        CompletionAppGenerateEntity,
        AgentChatAppGenerateEntity,
        AdvancedChatAppGenerateEntity
        ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
    ]
    _task_state: Union[EasyUITaskState, WorkflowTaskState]

@@ -49,15 +46,18 @@ class MessageCycleManage:
        is_first_message = self._application_generate_entity.conversation_id is None
        extras = self._application_generate_entity.extras
        auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
        auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)

        if auto_generate_conversation_name and is_first_message:
            # start generate thread
            thread = Thread(target=self._generate_conversation_name_worker, kwargs={
                'flask_app': current_app._get_current_object(),  # type: ignore
                'conversation_id': conversation.id,
                'query': query
            })
            thread = Thread(
                target=self._generate_conversation_name_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "conversation_id": conversation.id,
                    "query": query,
                },
            )

            thread.start()

@@ -65,17 +65,10 @@ class MessageCycleManage:
        return None

    def _generate_conversation_name_worker(self,
                                           flask_app: Flask,
                                           conversation_id: str,
                                           query: str):
    def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
        with flask_app.app_context():
            # get conversation and message
            conversation = (
                db.session.query(Conversation)
                .filter(Conversation.id == conversation_id)
                .first()
            )
            conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

            if not conversation:
                return

@@ -105,12 +98,9 @@ class MessageCycleManage:
            annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
            if annotation:
                account = annotation.account
                self._task_state.metadata['annotation_reply'] = {
                    'id': annotation.id,
                    'account': {
                        'id': annotation.account_id,
                        'name': account.name if account else 'Dify user'
                    }
                self._task_state.metadata["annotation_reply"] = {
                    "id": annotation.id,
                    "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
                }

                return annotation

@@ -124,7 +114,7 @@ class MessageCycleManage:
        :return:
        """
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata['retriever_resources'] = event.retriever_resources
            self._task_state.metadata["retriever_resources"] = event.retriever_resources

    def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
        """

@@ -132,27 +122,23 @@ class MessageCycleManage:
        :param event: event
        :return:
        """
        message_file = (
            db.session.query(MessageFile)
            .filter(MessageFile.id == event.message_file_id)
            .first()
        )
        message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

        if message_file:
            # get tool file id
            tool_file_id = message_file.url.split('/')[-1]
            tool_file_id = message_file.url.split("/")[-1]
            # trim extension
            tool_file_id = tool_file_id.split('.')[0]
            tool_file_id = tool_file_id.split(".")[0]

            # get extension
            if '.' in message_file.url:
            if "." in message_file.url:
                extension = f'.{message_file.url.split(".")[-1]}'
                if len(extension) > 10:
                    extension = '.bin'
                    extension = ".bin"
            else:
                extension = '.bin'
                extension = ".bin"
            # add sign url to local file
            if message_file.url.startswith('http'):
            if message_file.url.startswith("http"):
                url = message_file.url
            else:
                url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)

@@ -161,8 +147,8 @@ class MessageCycleManage:
                task_id=self._application_generate_entity.task_id,
                id=message_file.id,
                type=message_file.type,
                belongs_to=message_file.belongs_to or 'user',
                url=url
                belongs_to=message_file.belongs_to or "user",
                url=url,
            )

        return None

@@ -174,11 +160,7 @@ class MessageCycleManage:
        :param message_id: message id
        :return:
        """
        return MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer
        )
        return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)

    def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
        """

@@ -186,7 +168,4 @@ class MessageCycleManage:
        :param answer: answer
        :return:
        """
        return MessageReplaceStreamResponse(
            task_id=self._application_generate_entity.task_id,
            answer=answer
        )
        return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
@@ -70,14 +70,14 @@ class WorkflowCycleManage:
        inputs = {**self._application_generate_entity.inputs}
        for key, value in (self._workflow_system_variables or {}).items():
            if key.value == 'conversation':
            if key.value == "conversation":
                continue

            inputs[f'sys.{key.value}'] = value
            inputs[f"sys.{key.value}"] = value

        inputs = WorkflowEntry.handle_special_values(inputs)

        triggered_from= (
        triggered_from = (
            WorkflowRunTriggeredFrom.DEBUGGING
            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
            else WorkflowRunTriggeredFrom.APP_RUN

@@ -185,20 +185,26 @@
        db.session.commit()

        running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
            WorkflowNodeExecution.app_id == workflow_run.app_id,
            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
        ).all()
        running_workflow_node_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(
                WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
                WorkflowNodeExecution.app_id == workflow_run.app_id,
                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
                WorkflowNodeExecution.workflow_run_id == workflow_run.id,
                WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
            )
            .all()
        )

        for workflow_node_execution in running_workflow_node_executions:
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error
            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
            workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
            workflow_node_execution.elapsed_time = (
                workflow_node_execution.finished_at - workflow_node_execution.created_at
            ).total_seconds()
            db.session.commit()

        db.session.refresh(workflow_run)

@@ -216,7 +222,9 @@
        return workflow_run

    def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
    def _handle_node_execution_start(
        self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
    ) -> WorkflowNodeExecution:
        # init workflow node execution
        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.tenant_id = workflow_run.tenant_id

@@ -333,16 +341,16 @@
        created_by_account = workflow_run.created_by_account
        if created_by_account:
            created_by = {
                'id': created_by_account.id,
                'name': created_by_account.name,
                'email': created_by_account.email,
                "id": created_by_account.id,
                "name": created_by_account.name,
                "email": created_by_account.email,
            }
        else:
            created_by_end_user = workflow_run.created_by_end_user
            if created_by_end_user:
                created_by = {
                    'id': created_by_end_user.id,
                    'user': created_by_end_user.session_id,
                    "id": created_by_end_user.id,
                    "user": created_by_end_user.session_id,
                }

        return WorkflowFinishStreamResponse(

@@ -401,7 +409,7 @@
        # extras logic
        if event.node_type == NodeType.TOOL:
            node_data = cast(ToolNodeData, event.node_data)
            response.data.extras['icon'] = ToolManager.get_tool_icon(
            response.data.extras["icon"] = ToolManager.get_tool_icon(
                tenant_id=self._application_generate_entity.app_config.tenant_id,
                provider_type=node_data.provider_type,
                provider_id=node_data.provider_id,

@@ -410,10 +418,10 @@
        return response

    def _workflow_node_finish_to_stream_response(
        self,
        event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
        task_id: str,
        workflow_node_execution: WorkflowNodeExecution
        self,
        event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
        task_id: str,
        workflow_node_execution: WorkflowNodeExecution,
    ) -> Optional[NodeFinishStreamResponse]:
        """
        Workflow node finish to stream response.

@@ -424,7 +432,7 @@
        """
        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
            return None

        return NodeFinishStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_node_execution.workflow_run_id,

@@ -452,13 +460,10 @@
                iteration_id=event.in_iteration_id,
            ),
        )

    def _workflow_parallel_branch_start_to_stream_response(
        self,
        task_id: str,
        workflow_run: WorkflowRun,
        event: QueueParallelBranchRunStartedEvent
    ) -> ParallelBranchStartStreamResponse:
        self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
    ) -> ParallelBranchStartStreamResponse:
        """
        Workflow parallel branch start to stream response
        :param task_id: task id

@@ -476,15 +481,15 @@
                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                iteration_id=event.in_iteration_id,
                created_at=int(time.time()),
            )
            ),
        )

    def _workflow_parallel_branch_finished_to_stream_response(
        self,
        task_id: str,
        workflow_run: WorkflowRun,
        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
    ) -> ParallelBranchFinishedStreamResponse:
        self,
        task_id: str,
        workflow_run: WorkflowRun,
        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
    ) -> ParallelBranchFinishedStreamResponse:
        """
        Workflow parallel branch finished to stream response
        :param task_id: task id

@@ -501,18 +506,15 @@
                parent_parallel_id=event.parent_parallel_id,
                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                iteration_id=event.in_iteration_id,
                status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
                status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
                created_at=int(time.time()),
            )
            ),
        )

    def _workflow_iteration_start_to_stream_response(
        self,
        task_id: str,
        workflow_run: WorkflowRun,
        event: QueueIterationStartEvent
    ) -> IterationNodeStartStreamResponse:
        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
    ) -> IterationNodeStartStreamResponse:
        """
        Workflow iteration start to stream response
        :param task_id: task id

@@ -534,10 +536,12 @@
                metadata=event.metadata or {},
                parallel_id=event.parallel_id,
                parallel_start_node_id=event.parallel_start_node_id,
            )
            ),
        )

    def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
    def _workflow_iteration_next_to_stream_response(
        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
    ) -> IterationNodeNextStreamResponse:
        """
        Workflow iteration next to stream response
        :param task_id: task id

@@ -559,10 +563,12 @@
                extras={},
                parallel_id=event.parallel_id,
                parallel_start_node_id=event.parallel_start_node_id,
            )
            ),
        )

    def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
    def _workflow_iteration_completed_to_stream_response(
        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
    ) -> IterationNodeCompletedStreamResponse:
        """
        Workflow iteration completed to stream response
        :param task_id: task id

@@ -585,13 +591,13 @@
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                error=None,
                elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
                total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
                total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
                execution_metadata=event.metadata,
                finished_at=int(time.time()),
                steps=event.steps,
                parallel_id=event.parallel_id,
                parallel_start_node_id=event.parallel_start_node_id,
            )
            ),
        )

    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:

@@ -643,7 +649,7 @@
            return None

        if isinstance(value, dict):
            if '__variant' in value and value['__variant'] == FileVar.__name__:
            if "__variant" in value and value["__variant"] == FileVar.__name__:
                return value
        elif isinstance(value, FileVar):
            return value.to_dict()

@@ -656,11 +662,10 @@
        :param workflow_run_id: workflow run id
        :return:
        """
        workflow_run = db.session.query(WorkflowRun).filter(
            WorkflowRun.id == workflow_run_id).first()
        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()

        if not workflow_run:
            raise Exception(f'Workflow run not found: {workflow_run_id}')
            raise Exception(f"Workflow run not found: {workflow_run_id}")

        return workflow_run

@@ -683,6 +688,6 @@
        )

        if not workflow_node_execution:
            raise Exception(f'Workflow node execution not found: {node_execution_id}')
            raise Exception(f"Workflow node execution not found: {node_execution_id}")

        return workflow_node_execution
        return workflow_node_execution
@@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
    "red": "31;1",
}


def get_colored_text(text: str, color: str) -> str:
    """Get colored text."""
    color_str = _TEXT_COLOR_MAPPING[color]
    return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def print_text(
    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
    """Print text with highlighting and no end characters."""
    text_to_print = get_colored_text(text, color) if color else text
    print(text_to_print, end=end, file=file)
    if file:
        file.flush()  # ensure all printed content are written to file


class DifyAgentCallbackHandler(BaseModel):
    """Callback Handler that prints to std out."""
    color: Optional[str] = ''

    color: Optional[str] = ""
    current_loop: int = 1

    def __init__(self, color: Optional[str] = None) -> None:
        super().__init__()
        """Initialize callback handler."""
        # use a specific color is not specified
        self.color = color or 'green'
        self.color = color or "green"
        self.current_loop = 1

    def on_tool_start(

@@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
        tool_outputs: Sequence[ToolInvokeMessage],
        message_id: Optional[str] = None,
        timer: Optional[Any] = None,
        trace_manager: Optional[TraceQueueManager] = None
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> None:
        """If not the final action, print out observation."""
        print_text("\n[on_tool_end]\n", color=self.color)

@@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
            )
        )

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        """Do nothing."""
        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")

    def on_agent_start(
        self, thought: str
    ) -> None:
    def on_agent_start(self, thought: str) -> None:
        """Run on agent start."""
        if thought:
            print_text("\n[on_agent_start] \nCurrent Loop: " + \
                       str(self.current_loop) + \
                       "\nThought: " + thought + "\n", color=self.color)
            print_text(
                "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
                color=self.color,
            )
        else:
            print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

    def on_agent_finish(
        self, color: Optional[str] = None, **kwargs: Any
    ) -> None:
    def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
        """Run on agent end."""
        print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

@@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
@@ -1,4 +1,3 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

@@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
    """Callback handler for dataset tool."""

    def __init__(self, queue_manager: AppQueueManager,
                 app_id: str,
                 message_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom) -> None:
    def __init__(
        self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
    ) -> None:
        self._queue_manager = queue_manager
        self._app_id = app_id
        self._message_id = message_id

@@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
        dataset_query = DatasetQuery(
            dataset_id=dataset_id,
            content=query,
            source='app',
            source="app",
            source_app_id=self._app_id,
            created_by_role=('account'
                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
            created_by=self._user_id
            created_by_role=(
                "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
            ),
            created_by=self._user_id,
        )

        db.session.add(dataset_query)

@@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
        """Handle tool end."""
        for document in documents:
            query = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == document.metadata['doc_id']
                DocumentSegment.index_node_id == document.metadata["doc_id"]
            )

            # if 'dataset_id' in document.metadata:
            if 'dataset_id' in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
            if "dataset_id" in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])

            # add hit count to document segment
            query.update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                synchronize_session=False
            )
            query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

        db.session.commit()

@@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
        for item in resource:
            dataset_retriever_resource = DatasetRetrieverResource(
                message_id=self._message_id,
                position=item.get('position'),
                dataset_id=item.get('dataset_id'),
                dataset_name=item.get('dataset_name'),
                document_id=item.get('document_id'),
                document_name=item.get('document_name'),
                data_source_type=item.get('data_source_type'),
                segment_id=item.get('segment_id'),
                score=item.get('score') if 'score' in item else None,
                hit_count=item.get('hit_count') if 'hit_count' else None,
                word_count=item.get('word_count') if 'word_count' in item else None,
                segment_position=item.get('segment_position') if 'segment_position' in item else None,
                index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
                content=item.get('content'),
                retriever_from=item.get('retriever_from'),
                created_by=self._user_id
                position=item.get("position"),
                dataset_id=item.get("dataset_id"),
                dataset_name=item.get("dataset_name"),
                document_id=item.get("document_id"),
                document_name=item.get("document_name"),
                data_source_type=item.get("data_source_type"),
                segment_id=item.get("segment_id"),
                score=item.get("score") if "score" in item else None,
                hit_count=item.get("hit_count") if "hit_count" else None,
                word_count=item.get("word_count") if "word_count" in item else None,
                segment_position=item.get("segment_position") if "segment_position" in item else None,
                index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
                content=item.get("content"),
                retriever_from=item.get("retriever_from"),
                created_by=self._user_id,
            )
            db.session.add(dataset_retriever_resource)
        db.session.commit()

        self._queue_manager.publish(
            QueueRetrieverResourcesEvent(retriever_resources=resource),
            PublishFrom.APPLICATION_MANAGER
            QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
        )
@@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
    """Callback Handler that prints to std out."""
    """Callback Handler that prints to std out."""
@@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
        embedding_queue_indices = []
        for i, text in enumerate(texts):
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
                                                              hash=hash,
                                                              provider_name=self._model_instance.provider).first()
            embedding = (
                db.session.query(Embedding)
                .filter_by(
                    model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
                )
                .first()
            )
            if embedding:
                text_embeddings[i] = embedding.get_embedding()
            else:

@@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
            embedding_queue_embeddings = []
            try:
                model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
                model_schema = model_type_instance.get_model_schema(self._model_instance.model,
                                                                    self._model_instance.credentials)
                max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
                model_schema = model_type_instance.get_model_schema(
                    self._model_instance.model, self._model_instance.credentials
                )
                max_chunks = (
                    model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
                    else 1
                )
                for i in range(0, len(embedding_queue_texts), max_chunks):
                    batch_texts = embedding_queue_texts[i:i + max_chunks]
                    batch_texts = embedding_queue_texts[i : i + max_chunks]

                    embedding_result = self._model_instance.invoke_text_embedding(
                        texts=batch_texts,
                        user=self._user
                    )
                    embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)

                    for vector in embedding_result.embeddings:
                        try:

@@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
            except IntegrityError:
                db.session.rollback()
            except Exception as e:
                logging.exception('Failed transform embedding: ', e)
                logging.exception("Failed transform embedding: ", e)
            cache_embeddings = []
            try:
                for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
                    text_embeddings[i] = embedding
                    hash = helper.generate_text_hash(texts[i])
                    if hash not in cache_embeddings:
                        embedding_cache = Embedding(model_name=self._model_instance.model,
                                                    hash=hash,
                                                    provider_name=self._model_instance.provider)
                        embedding_cache = Embedding(
                            model_name=self._model_instance.model,
                            hash=hash,
                            provider_name=self._model_instance.provider,
                        )
                        embedding_cache.set_embedding(embedding)
                        db.session.add(embedding_cache)
                        cache_embeddings.append(hash)

@@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
                db.session.rollback()
        except Exception as ex:
            db.session.rollback()
            logger.error('Failed to embed documents: ', ex)
            logger.error("Failed to embed documents: ", ex)
            raise ex

        return text_embeddings

@@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
        embedding = redis_client.get(embedding_cache_key)
        if embedding:
            redis_client.expire(embedding_cache_key, 600)
            return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
        try:
            embedding_result = self._model_instance.invoke_text_embedding(
                texts=[text],
                user=self._user
            )
            embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)

            embedding_results = embedding_result.embeddings[0]
            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()

@@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to redis')
            logging.exception("Failed to add embedding to redis")

        return embedding_results
@@ -2,7 +2,7 @@ from enum import Enum
class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    ROUTER = "router"
    REACT_ROUTER = "react_router"
    REACT = "react"
    FUNCTION_CALL = "function_call"
@@ -5,7 +5,7 @@ from pydantic import BaseModel
class PromptMessageFileType(enum.Enum):
    IMAGE = 'image'
    IMAGE = "image"

    @staticmethod
    def value_of(value):

@@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel):
class ImagePromptMessageFile(PromptMessageFile):
    class DETAIL(enum.Enum):
        LOW = 'low'
        HIGH = 'high'
        LOW = "low"
        HIGH = "high"

    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    detail: DETAIL = DETAIL.LOW
@@ -12,6 +12,7 @@ class ModelStatus(Enum):
    """
    Enum class for model status.
    """

    ACTIVE = "active"
    NO_CONFIGURE = "no-configure"
    QUOTA_EXCEEDED = "quota-exceeded"

@@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel):
    """
    Simple provider.
    """

    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None

@@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel):
            label=provider_entity.label,
            icon_small=provider_entity.icon_small,
            icon_large=provider_entity.icon_large,
            supported_model_types=provider_entity.supported_model_types
            supported_model_types=provider_entity.supported_model_types,
        )

@@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
    """
    Model class for model response.
    """

    status: ModelStatus
    load_balancing_enabled: bool = False

@@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity):
    """
    Model with provider entity.
    """

    provider: SimpleModelProviderEntity

@@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel):
    """
    Default model provider entity.
    """

    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None

@@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel):
    """
    Default model entity.
    """

    model: str
    model_type: ModelType
    provider: DefaultModelProviderEntity
@@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
     """
     Model class for provider configuration.
     """
+
     tenant_id: str
     provider: ProviderEntity
     preferred_provider_type: ProviderType

@@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
             original_provider_configurate_methods[self.provider.provider].append(configurate_method)

         if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
-            if (any(len(quota_configuration.restrict_models) > 0
-                    for quota_configuration in self.system_configuration.quota_configurations)
-                    and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
+            if (
+                any(
+                    len(quota_configuration.restrict_models) > 0
+                    for quota_configuration in self.system_configuration.quota_configurations
+                )
+                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
+            ):
                 self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

     def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:

@@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
         if self.model_settings:
             # check if model is disabled by admin
             for model_setting in self.model_settings:
-                if (model_setting.model_type == model_type
-                        and model_setting.model == model):
+                if model_setting.model_type == model_type and model_setting.model == model:
                     if not model_setting.enabled:
-                        raise ValueError(f'Model {model} is disabled.')
+                        raise ValueError(f"Model {model} is disabled.")

         if self.using_provider_type == ProviderType.SYSTEM:
             restrict_models = []

@@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
             copy_credentials = self.system_configuration.credentials.copy()
             if restrict_models:
                 for restrict_model in restrict_models:
-                    if (restrict_model.model_type == model_type
-                            and restrict_model.model == model
-                            and restrict_model.base_model_name):
-                        copy_credentials['base_model_name'] = restrict_model.base_model_name
+                    if (
+                        restrict_model.model_type == model_type
+                        and restrict_model.model == model
+                        and restrict_model.base_model_name
+                    ):
+                        copy_credentials["base_model_name"] = restrict_model.base_model_name

             return copy_credentials
         else:

@@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):

         current_quota_type = self.system_configuration.current_quota_type
         current_quota_configuration = next(
-            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
-            None
+            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
         )

-        return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
-            SystemConfigurationStatus.QUOTA_EXCEEDED
+        return (
+            SystemConfigurationStatus.ACTIVE
+            if current_quota_configuration.is_valid
+            else SystemConfigurationStatus.QUOTA_EXCEEDED
+        )

     def is_custom_configuration_available(self) -> bool:
         """
         Check custom configuration available.
         :return:
         """
-        return (self.custom_configuration.provider is not None
-                or len(self.custom_configuration.models) > 0)
+        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

     def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
         """

@@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
         return self.obfuscated_credentials(
             credentials=credentials,
             credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
-            if self.provider.provider_credential_schema else []
+            if self.provider.provider_credential_schema
+            else [],
         )

     def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
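Note: where the old code wrapped a long boolean with hanging indentation inside `if (...)`, the reformatted code places each operand on its own line between a lone `(` and a closing `):` back at statement level. A small standalone sketch of the pattern:

    quota_limits = [0, 3, 0]
    supports_predefined = False

    # ruff/Black style: one operand per line, no continuation backslashes.
    if (
        any(limit > 0 for limit in quota_limits)
        and not supports_predefined
    ):
        print("enable predefined models")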
@@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider
-        provider_record = db.session.query(Provider) \
-            .filter(
-                Provider.tenant_id == self.tenant_id,
-                Provider.provider_name == self.provider.provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).first()
+        provider_record = (
+            db.session.query(Provider)
+            .filter(
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == self.provider.provider,
+                Provider.provider_type == ProviderType.CUSTOM.value,
+            )
+            .first()
+        )

         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.provider_credential_schema.credential_form_schemas
-            if self.provider.provider_credential_schema else []
+            if self.provider.provider_credential_schema
+            else []
         )

         if provider_record:

@@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
             # fix origin data
             if provider_record.encrypted_config:
                 if not provider_record.encrypted_config.startswith("{"):
-                    original_credentials = {
-                        "openai_api_key": provider_record.encrypted_config
-                    }
+                    original_credentials = {"openai_api_key": provider_record.encrypted_config}
                 else:
                     original_credentials = json.loads(provider_record.encrypted_config)
             else:

@@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
                     credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

         credentials = model_provider_factory.provider_credentials_validate(
-            provider=self.provider.provider,
-            credentials=credentials
+            provider=self.provider.provider, credentials=credentials
         )

         for key, value in credentials.items():

@@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 provider_type=ProviderType.CUSTOM.value,
                 encrypted_config=json.dumps(credentials),
-                is_valid=True
+                is_valid=True,
             )
             db.session.add(provider_record)
             db.session.commit()

         provider_model_credentials_cache = ProviderCredentialsCache(
-            tenant_id=self.tenant_id,
-            identity_id=provider_record.id,
-            cache_type=ProviderCredentialsCacheType.PROVIDER
+            tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
         )

         provider_model_credentials_cache.delete()

@@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider
-        provider_record = db.session.query(Provider) \
-            .filter(
-                Provider.tenant_id == self.tenant_id,
-                Provider.provider_name == self.provider.provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).first()
+        provider_record = (
+            db.session.query(Provider)
+            .filter(
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == self.provider.provider,
+                Provider.provider_type == ProviderType.CUSTOM.value,
+            )
+            .first()
+        )

         # delete provider
         if provider_record:

@@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
             provider_model_credentials_cache = ProviderCredentialsCache(
                 tenant_id=self.tenant_id,
                 identity_id=provider_record.id,
-                cache_type=ProviderCredentialsCacheType.PROVIDER
+                cache_type=ProviderCredentialsCacheType.PROVIDER,
             )

             provider_model_credentials_cache.delete()

-    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-            -> Optional[dict]:
+    def get_custom_model_credentials(
+        self, model_type: ModelType, model: str, obfuscated: bool = False
+    ) -> Optional[dict]:
         """
         Get custom model credentials.

@@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
                 return self.obfuscated_credentials(
                     credentials=credentials,
                     credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
-                    if self.provider.model_credential_schema else []
+                    if self.provider.model_credential_schema
+                    else [],
                 )

         return None

-    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-            -> tuple[ProviderModel, dict]:
+    def custom_model_credentials_validate(
+        self, model_type: ModelType, model: str, credentials: dict
+    ) -> tuple[ProviderModel, dict]:
         """
         Validate custom model credentials.

@@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider model
-        provider_model_record = db.session.query(ProviderModel) \
-            .filter(
-                ProviderModel.tenant_id == self.tenant_id,
-                ProviderModel.provider_name == self.provider.provider,
-                ProviderModel.model_name == model,
-                ProviderModel.model_type == model_type.to_origin_model_type()
-            ).first()
+        provider_model_record = (
+            db.session.query(ProviderModel)
+            .filter(
+                ProviderModel.tenant_id == self.tenant_id,
+                ProviderModel.provider_name == self.provider.provider,
+                ProviderModel.model_name == model,
+                ProviderModel.model_type == model_type.to_origin_model_type(),
+            )
+            .first()
+        )

         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.model_credential_schema.credential_form_schemas
-            if self.provider.model_credential_schema else []
+            if self.provider.model_credential_schema
+            else []
         )

         if provider_model_record:
             try:
-                original_credentials = json.loads(
-                    provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+                original_credentials = (
+                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+                )
             except JSONDecodeError:
                 original_credentials = {}

@@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
                     credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

         credentials = model_provider_factory.model_credentials_validate(
-            provider=self.provider.provider,
-            model_type=model_type,
-            model=model,
-            credentials=credentials
+            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
         )

         for key, value in credentials.items():

@@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
                 model_name=model,
                 model_type=model_type.to_origin_model_type(),
                 encrypted_config=json.dumps(credentials),
-                is_valid=True
+                is_valid=True,
             )
             db.session.add(provider_model_record)
             db.session.commit()

@@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
         provider_model_credentials_cache = ProviderCredentialsCache(
             tenant_id=self.tenant_id,
             identity_id=provider_model_record.id,
-            cache_type=ProviderCredentialsCacheType.MODEL
+            cache_type=ProviderCredentialsCacheType.MODEL,
         )

         provider_model_credentials_cache.delete()
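Note: the recurring change in these hunks is replacing backslash-continued SQLAlchemy chains with a parenthesized method chain. A self-contained sketch of the pattern against a hypothetical table (illustrative, not Dify's models):

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Provider(Base):
        __tablename__ = "providers"
        id = Column(String, primary_key=True)
        tenant_id = Column(String)
        provider_name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Parentheses let the chain break cleanly, one method per line.
        record = (
            session.query(Provider)
            .filter(
                Provider.tenant_id == "tenant-1",
                Provider.provider_name == "openai",
            )
            .first()
        )
        print(record)  # None here; no rows were inserted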
@@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider model
-        provider_model_record = db.session.query(ProviderModel) \
-            .filter(
-                ProviderModel.tenant_id == self.tenant_id,
-                ProviderModel.provider_name == self.provider.provider,
-                ProviderModel.model_name == model,
-                ProviderModel.model_type == model_type.to_origin_model_type()
-            ).first()
+        provider_model_record = (
+            db.session.query(ProviderModel)
+            .filter(
+                ProviderModel.tenant_id == self.tenant_id,
+                ProviderModel.provider_name == self.provider.provider,
+                ProviderModel.model_name == model,
+                ProviderModel.model_type == model_type.to_origin_model_type(),
+            )
+            .first()
+        )

         # delete provider model
         if provider_model_record:

@@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
             provider_model_credentials_cache = ProviderCredentialsCache(
                 tenant_id=self.tenant_id,
                 identity_id=provider_model_record.id,
-                cache_type=ProviderCredentialsCacheType.MODEL
+                cache_type=ProviderCredentialsCacheType.MODEL,
             )

             provider_model_credentials_cache.delete()

@@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model
-            ).first()
+        model_setting = (
+            db.session.query(ProviderModelSetting)
+            .filter(
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )

         if model_setting:
             model_setting.enabled = True

@@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                enabled=True
+                enabled=True,
             )
             db.session.add(model_setting)
             db.session.commit()

@@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model
-            ).first()
+        model_setting = (
+            db.session.query(ProviderModelSetting)
+            .filter(
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )

         if model_setting:
             model_setting.enabled = False

@@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                enabled=False
+                enabled=False,
             )
             db.session.add(model_setting)
             db.session.commit()

@@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        return db.session.query(ProviderModelSetting) \
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model
-            ).first()
+        return (
+            db.session.query(ProviderModelSetting)
+            .filter(
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )

     def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
         """

@@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
-            .filter(
-                LoadBalancingModelConfig.tenant_id == self.tenant_id,
-                LoadBalancingModelConfig.provider_name == self.provider.provider,
-                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-                LoadBalancingModelConfig.model_name == model
-            ).count()
+        load_balancing_config_count = (
+            db.session.query(LoadBalancingModelConfig)
+            .filter(
+                LoadBalancingModelConfig.tenant_id == self.tenant_id,
+                LoadBalancingModelConfig.provider_name == self.provider.provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+            )
+            .count()
+        )

         if load_balancing_config_count <= 1:
-            raise ValueError('Model load balancing configuration must be more than 1.')
+            raise ValueError("Model load balancing configuration must be more than 1.")

-        model_setting = db.session.query(ProviderModelSetting) \
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model
-            ).first()
+        model_setting = (
+            db.session.query(ProviderModelSetting)
+            .filter(
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )

         if model_setting:
             model_setting.load_balancing_enabled = True

@@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                load_balancing_enabled=True
+                load_balancing_enabled=True,
             )
             db.session.add(model_setting)
             db.session.commit()

@@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
-            .filter(
-                ProviderModelSetting.tenant_id == self.tenant_id,
-                ProviderModelSetting.provider_name == self.provider.provider,
-                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-                ProviderModelSetting.model_name == model
-            ).first()
+        model_setting = (
+            db.session.query(ProviderModelSetting)
+            .filter(
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )

         if model_setting:
             model_setting.load_balancing_enabled = False

@@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                load_balancing_enabled=False
+                load_balancing_enabled=False,
            )
             db.session.add(model_setting)
             db.session.commit()
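Note: the enable/disable helpers above all share one get-or-create shape: look the settings row up, mutate it if present, otherwise insert a fresh one, then commit. A framework-free sketch of that control flow (plain dicts stand in for the ORM rows; names are illustrative):

    def upsert_setting(registry: dict[str, dict], key: str, **fields) -> dict:
        setting = registry.get(key)
        if setting:
            setting.update(fields)            # update the existing row
        else:
            setting = {"key": key, **fields}  # insert a new row
            registry[key] = setting
        return setting

    settings: dict[str, dict] = {}
    upsert_setting(settings, "openai/gpt-4/llm", enabled=True)
    upsert_setting(settings, "openai/gpt-4/llm", load_balancing_enabled=False)
    print(settings)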
@@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
             return

         # get preferred provider
-        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
-            .filter(
-                TenantPreferredModelProvider.tenant_id == self.tenant_id,
-                TenantPreferredModelProvider.provider_name == self.provider.provider
-            ).first()
+        preferred_model_provider = (
+            db.session.query(TenantPreferredModelProvider)
+            .filter(
+                TenantPreferredModelProvider.tenant_id == self.tenant_id,
+                TenantPreferredModelProvider.provider_name == self.provider.provider,
+            )
+            .first()
+        )

         if preferred_model_provider:
             preferred_model_provider.preferred_provider_type = provider_type.value

@@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
             preferred_model_provider = TenantPreferredModelProvider(
                 tenant_id=self.tenant_id,
                 provider_name=self.provider.provider,
-                preferred_provider_type=provider_type.value
+                preferred_provider_type=provider_type.value,
             )
             db.session.add(preferred_model_provider)

@@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # Get provider credential secret variables
-        credential_secret_variables = self.extract_secret_variables(
-            credential_form_schemas
-        )
+        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

         # Obfuscate provider credentials
         copy_credentials = credentials.copy()

@@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):

         return copy_credentials

-    def get_provider_model(self, model_type: ModelType,
-                           model: str,
-                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
+    def get_provider_model(
+        self, model_type: ModelType, model: str, only_active: bool = False
+    ) -> Optional[ModelWithProviderEntity]:
         """
         Get provider model.
         :param model_type: model type

@@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):

         return None

-    def get_provider_models(self, model_type: Optional[ModelType] = None,
-                            only_active: bool = False) -> list[ModelWithProviderEntity]:
+    def get_provider_models(
+        self, model_type: Optional[ModelType] = None, only_active: bool = False
+    ) -> list[ModelWithProviderEntity]:
         """
         Get provider models.
         :param model_type: model type

@@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):

         if self.using_provider_type == ProviderType.SYSTEM:
             provider_models = self._get_system_provider_models(
-                model_types=model_types,
-                provider_instance=provider_instance,
-                model_setting_map=model_setting_map
+                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
             )
         else:
             provider_models = self._get_custom_provider_models(
-                model_types=model_types,
-                provider_instance=provider_instance,
-                model_setting_map=model_setting_map
+                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
             )

         if only_active:

@@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
         # resort provider_models
         return sorted(provider_models, key=lambda x: x.model_type.value)

-    def _get_system_provider_models(self,
-                                    model_types: list[ModelType],
-                                    provider_instance: ModelProvider,
-                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-            -> list[ModelWithProviderEntity]:
+    def _get_system_provider_models(
+        self,
+        model_types: list[ModelType],
+        provider_instance: ModelProvider,
+        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
+    ) -> list[ModelWithProviderEntity]:
         """
         Get system provider models.

@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
                             model_properties=m.model_properties,
                             deprecated=m.deprecated,
                             provider=SimpleModelProviderEntity(self.provider),
-                            status=status
+                            status=status,
                         )
                     )

@@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):

         if should_use_custom_model:
             if original_provider_configurate_methods[self.provider.provider] == [
-                    ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+                ConfigurateMethod.CUSTOMIZABLE_MODEL
+            ]:
                 # only customizable model
                 for restrict_model in restrict_models:
                     copy_credentials = self.system_configuration.credentials.copy()
                     if restrict_model.base_model_name:
-                        copy_credentials['base_model_name'] = restrict_model.base_model_name
+                        copy_credentials["base_model_name"] = restrict_model.base_model_name

                     try:
-                        custom_model_schema = (
-                            provider_instance.get_model_instance(restrict_model.model_type)
-                            .get_customizable_model_schema_from_credentials(
-                                restrict_model.model,
-                                copy_credentials
-                            )
-                        )
+                        custom_model_schema = provider_instance.get_model_instance(
+                            restrict_model.model_type
+                        ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                     except Exception as ex:
-                        logger.warning(f'get custom model schema failed, {ex}')
+                        logger.warning(f"get custom model schema failed, {ex}")
                         continue

                     if not custom_model_schema:

@@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
                         continue

                     status = ModelStatus.ACTIVE
-                    if (custom_model_schema.model_type in model_setting_map
-                            and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                    if (
+                        custom_model_schema.model_type in model_setting_map
+                        and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+                    ):
                         model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                         if model_setting.enabled is False:
                             status = ModelStatus.DISABLED

@@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
                             model_properties=custom_model_schema.model_properties,
                             deprecated=custom_model_schema.deprecated,
                             provider=SimpleModelProviderEntity(self.provider),
-                            status=status
+                            status=status,
                         )
                     )

@@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):

         return provider_models

-    def _get_custom_provider_models(self,
-                                    model_types: list[ModelType],
-                                    provider_instance: ModelProvider,
-                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-            -> list[ModelWithProviderEntity]:
+    def _get_custom_provider_models(
+        self,
+        model_types: list[ModelType],
+        provider_instance: ModelProvider,
+        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
+    ) -> list[ModelWithProviderEntity]:
         """
         Get custom provider models.

@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
                         status=status,
-                        load_balancing_enabled=load_balancing_enabled
+                        load_balancing_enabled=load_balancing_enabled,
                     )
                 )

@@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
                 continue

             try:
-                custom_model_schema = (
-                    provider_instance.get_model_instance(model_configuration.model_type)
-                    .get_customizable_model_schema_from_credentials(
-                        model_configuration.model,
-                        model_configuration.credentials
-                    )
-                )
+                custom_model_schema = provider_instance.get_model_instance(
+                    model_configuration.model_type
+                ).get_customizable_model_schema_from_credentials(
+                    model_configuration.model, model_configuration.credentials
+                )
             except Exception as ex:
-                logger.warning(f'get custom model schema failed, {ex}')
+                logger.warning(f"get custom model schema failed, {ex}")
                 continue

             if not custom_model_schema:

@@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):

             status = ModelStatus.ACTIVE
             load_balancing_enabled = False
-            if (custom_model_schema.model_type in model_setting_map
-                    and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+            if (
+                custom_model_schema.model_type in model_setting_map
+                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+            ):
                 model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                 if model_setting.enabled is False:
                     status = ModelStatus.DISABLED

@@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
                     deprecated=custom_model_schema.deprecated,
                     provider=SimpleModelProviderEntity(self.provider),
                     status=status,
-                    load_balancing_enabled=load_balancing_enabled
+                    load_balancing_enabled=load_balancing_enabled,
                 )
             )

@@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
     """
     Model class for provider configuration dict.
     """
+
     tenant_id: str
     configurations: dict[str, ProviderConfiguration] = {}

     def __init__(self, tenant_id: str):
         super().__init__(tenant_id=tenant_id)

-    def get_models(self,
-                   provider: Optional[str] = None,
-                   model_type: Optional[ModelType] = None,
-                   only_active: bool = False) \
-            -> list[ModelWithProviderEntity]:
+    def get_models(
+        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
+    ) -> list[ModelWithProviderEntity]:
         """
         Get available models.

@@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
     """
     Provider model bundle.
     """
+
     configuration: ProviderConfiguration
     provider_instance: ModelProvider
     model_type_instance: AIModel

     # pydantic configs
-    model_config = ConfigDict(arbitrary_types_allowed=True,
-                              protected_namespaces=())
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
@@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType


 class QuotaUnit(Enum):
-    TIMES = 'times'
-    TOKENS = 'tokens'
-    CREDITS = 'credits'
+    TIMES = "times"
+    TOKENS = "tokens"
+    CREDITS = "credits"


 class SystemConfigurationStatus(Enum):
     """
     Enum class for system configuration status.
     """
-    ACTIVE = 'active'
-    QUOTA_EXCEEDED = 'quota-exceeded'
-    UNSUPPORTED = 'unsupported'
+
+    ACTIVE = "active"
+    QUOTA_EXCEEDED = "quota-exceeded"
+    UNSUPPORTED = "unsupported"


 class RestrictModel(BaseModel):

@@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel):
     """
     Model class for provider quota configuration.
     """
+
     quota_type: ProviderQuotaType
     quota_unit: QuotaUnit
     quota_limit: int

@@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel):
     """
     Model class for provider system configuration.
     """
+
     enabled: bool
     current_quota_type: Optional[ProviderQuotaType] = None
     quota_configurations: list[QuotaConfiguration] = []

@@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """
+
     credentials: dict

@@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel):
     """
     Model class for provider custom model configuration.
     """
+
     model: str
     model_type: ModelType
     credentials: dict

@@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """
+
     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []

@@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
     """
     Class for model load balancing configuration.
     """
+
     id: str
     name: str
     credentials: dict

@@ -93,6 +100,7 @@ class ModelSettings(BaseModel):
     """
     Model class for model settings.
     """
+
     model: str
     model_type: ModelType
     enabled: bool = True
@@ -3,6 +3,7 @@ from typing import Optional

 class LLMError(Exception):
     """Base class for all LLM exceptions."""
+
     description: Optional[str] = None

     def __init__(self, description: Optional[str] = None) -> None:

@@ -11,6 +12,7 @@ class LLMError(Exception):

 class LLMBadRequestError(LLMError):
     """Raised when the LLM returns bad request."""
+
     description = "Bad Request"


@@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception):
     """
     Custom exception raised when the provider token is not initialized.
     """
+
     description = "Provider Token Not Init"

     def __init__(self, *args, **kwargs):

@@ -28,6 +31,7 @@ class QuotaExceededError(Exception):
     """
     Custom exception raised when the quota for a provider has been exceeded.
     """
+
     description = "Quota Exceeded"


@@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception):
     """
     Custom exception raised when the quota for an app has been exceeded.
     """
+
     description = "App Invoke Quota Exceeded"


@@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception):
     """
     Custom exception raised when the model is not supported.
     """
+
     description = "Model Currently Not Support"


 class InvokeRateLimitError(Exception):
     """Raised when the Invoke returns rate limit error."""
+
     description = "Rate Limit Error"
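Note: every exception in this module carries a class-level `description`, so callers can report a human-readable reason without inspecting the message. A condensed sketch of that usage (the real module keeps several of these as plain `Exception` subclasses rather than one hierarchy):

    from typing import Optional

    class LLMError(Exception):
        """Base class for all LLM exceptions."""
        description: Optional[str] = None

    class QuotaExceededError(LLMError):
        description = "Quota Exceeded"

    try:
        raise QuotaExceededError()
    except LLMError as e:
        # Branch on the shared attribute instead of the message text.
        print(e.description or "unknown LLM error")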
@@ -20,10 +20,7 @@ class APIBasedExtensionRequestor:
         :param params: the request params
         :return: the response json
         """
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": "Bearer {}".format(self.api_key)
-        }
+        headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}

         url = self.api_endpoint

@@ -32,20 +29,17 @@ class APIBasedExtensionRequestor:
             proxies = None
             if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                 proxies = {
-                    'http': dify_config.SSRF_PROXY_HTTP_URL,
-                    'https': dify_config.SSRF_PROXY_HTTPS_URL,
+                    "http": dify_config.SSRF_PROXY_HTTP_URL,
+                    "https": dify_config.SSRF_PROXY_HTTPS_URL,
                 }

             response = requests.request(
-                method='POST',
+                method="POST",
                 url=url,
-                json={
-                    'point': point.value,
-                    'params': params
-                },
+                json={"point": point.value, "params": params},
                 headers=headers,
                 timeout=self.timeout,
-                proxies=proxies
+                proxies=proxies,
             )
         except requests.exceptions.Timeout:
             raise ValueError("request timeout")

@@ -53,9 +47,8 @@ class APIBasedExtensionRequestor:
             raise ValueError("request connection error")

         if response.status_code != 200:
-            raise ValueError("request error, status_code: {}, content: {}".format(
-                response.status_code,
-                response.text[:100]
-            ))
+            raise ValueError(
+                "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
+            )

         return response.json()
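Note: the requestor above is a thin wrapper over a single POST with a bearer token, a `{"point": ..., "params": ...}` body, and a hard status check. A minimal standalone sketch of that calling convention (endpoint, key, and timeout values are placeholders):

    import requests

    def call_extension(api_endpoint: str, api_key: str, point: str, params: dict) -> dict:
        response = requests.request(
            method="POST",
            url=api_endpoint,
            json={"point": point, "params": params},
            headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
            timeout=(5, 60),  # (connect, read) seconds; the real class stores its own timeout
        )
        if response.status_code != 200:
            raise ValueError(f"request error, status_code: {response.status_code}")
        return response.json()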
@@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map


 class ExtensionModule(enum.Enum):
-    MODERATION = 'moderation'
-    EXTERNAL_DATA_TOOL = 'external_data_tool'
+    MODERATION = "moderation"
+    EXTERNAL_DATA_TOOL = "external_data_tool"


 class ModuleExtension(BaseModel):

@@ -41,12 +41,12 @@ class Extensible:
         position_map = {}

         # get the path of the current class
-        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
+        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
         current_dir_path = os.path.dirname(current_path)

         # traverse subdirectories
         for subdir_name in os.listdir(current_dir_path):
-            if subdir_name.startswith('__'):
+            if subdir_name.startswith("__"):
                 continue

             subdir_path = os.path.join(current_dir_path, subdir_name)

@@ -58,21 +58,21 @@ class Extensible:
                 # in the front-end page and business logic, there are special treatments.
                 builtin = False
                 position = None
-                if '__builtin__' in file_names:
+                if "__builtin__" in file_names:
                     builtin = True

-                    builtin_file_path = os.path.join(subdir_path, '__builtin__')
+                    builtin_file_path = os.path.join(subdir_path, "__builtin__")
                     if os.path.exists(builtin_file_path):
-                        with open(builtin_file_path, encoding='utf-8') as f:
+                        with open(builtin_file_path, encoding="utf-8") as f:
                             position = int(f.read().strip())
                     position_map[extension_name] = position

-                if (extension_name + '.py') not in file_names:
+                if (extension_name + ".py") not in file_names:
                     logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
                     continue

                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
-                py_path = os.path.join(subdir_path, extension_name + '.py')
+                py_path = os.path.join(subdir_path, extension_name + ".py")
                 spec = importlib.util.spec_from_file_location(extension_name, py_path)
                 if not spec or not spec.loader:
                     raise Exception(f"Failed to load module {extension_name} from {py_path}")

@@ -91,25 +91,29 @@ class Extensible:

                 json_data = {}
                 if not builtin:
-                    if 'schema.json' not in file_names:
+                    if "schema.json" not in file_names:
                         logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                         continue

-                    json_path = os.path.join(subdir_path, 'schema.json')
+                    json_path = os.path.join(subdir_path, "schema.json")
                     json_data = {}
                     if os.path.exists(json_path):
-                        with open(json_path, encoding='utf-8') as f:
+                        with open(json_path, encoding="utf-8") as f:
                             json_data = json.load(f)

-                extensions.append(ModuleExtension(
-                    extension_class=extension_class,
-                    name=extension_name,
-                    label=json_data.get('label'),
-                    form_schema=json_data.get('form_schema'),
-                    builtin=builtin,
-                    position=position
-                ))
+                extensions.append(
+                    ModuleExtension(
+                        extension_class=extension_class,
+                        name=extension_name,
+                        label=json_data.get("label"),
+                        form_schema=json_data.get("form_schema"),
+                        builtin=builtin,
+                        position=position,
+                    )
+                )

-        sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
+        sorted_extensions = sort_to_dict_by_position_map(
+            position_map=position_map, data=extensions, name_func=lambda x: x.name
+        )

         return sorted_extensions
@@ -6,10 +6,7 @@ from core.moderation.base import Moderation
 class Extension:
     __module_extensions: dict[str, dict[str, ModuleExtension]] = {}

-    module_classes = {
-        ExtensionModule.MODERATION: Moderation,
-        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
-    }
+    module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool}

     def init(self):
         for module, module_class in self.module_classes.items():
@@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
             raise ValueError("api_based_extension_id is required")

         # get api_based_extension
-        api_based_extension = db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+        api_based_extension = (
+            db.session.query(APIBasedExtension)
+            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+            .first()
+        )

         if not api_based_extension:
             raise ValueError("api_based_extension_id is invalid")

@@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool):
         api_based_extension_id = self.config.get("api_based_extension_id")

         # get api_based_extension
-        api_based_extension = db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == self.tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+        api_based_extension = (
+            db.session.query(APIBasedExtension)
+            .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
+            .first()
+        )

         if not api_based_extension:
-            raise ValueError("[External data tool] API query failed, variable: {}, "
-                             "error: api_based_extension_id is invalid"
-                             .format(self.variable))
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, "
+                "error: api_based_extension_id is invalid".format(self.variable)
+            )

         # decrypt api_key
-        api_key = encrypter.decrypt_token(
-            tenant_id=self.tenant_id,
-            token=api_based_extension.api_key
-        )
+        api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key)

         try:
             # request api
-            requestor = APIBasedExtensionRequestor(
-                api_endpoint=api_based_extension.api_endpoint,
-                api_key=api_key
-            )
+            requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
         except Exception as e:
-            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
-                self.variable,
-                e
-            ))
+            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))

-        response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
-            'app_id': self.app_id,
-            'tool_variable': self.variable,
-            'inputs': inputs,
-            'query': query
-        })
+        response_json = requestor.request(
+            point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
+            params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
+        )

-        if 'result' not in response_json:
-            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
-                             .format(self.variable))
+        if "result" not in response_json:
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, error: result not found in response".format(
+                    self.variable
+                )
+            )

-        if not isinstance(response_json['result'], str):
-            raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string"
-                             .format(self.variable))
+        if not isinstance(response_json["result"], str):
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
+            )

-        return response_json['result']
+        return response_json["result"]
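Note: the validation at the end of query() pins down the wire contract for API-based external data tools — the remote extension must answer with a JSON object whose "result" field is a string. A standalone sketch of just that check:

    def parse_external_result(response_json: dict, variable: str) -> str:
        if "result" not in response_json:
            raise ValueError(f"[External data tool] API query failed, variable: {variable}, "
                             "error: result not found in response")
        if not isinstance(response_json["result"], str):
            raise ValueError(f"[External data tool] API query failed, variable: {variable}, "
                             "error: result is not string")
        return response_json["result"]

    print(parse_external_result({"result": "vip"}, "member_level"))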
@@ -12,11 +12,14 @@ logger = logging.getLogger(__name__)


 class ExternalDataFetch:
-    def fetch(self, tenant_id: str,
-              app_id: str,
-              external_data_tools: list[ExternalDataVariableEntity],
-              inputs: dict,
-              query: str) -> dict:
+    def fetch(
+        self,
+        tenant_id: str,
+        app_id: str,
+        external_data_tools: list[ExternalDataVariableEntity],
+        inputs: dict,
+        query: str,
+    ) -> dict:
         """
         Fill in variable inputs from external data tools if exists.

@@ -38,7 +41,7 @@ class ExternalDataFetch:
                     app_id,
                     tool,
                     inputs,
-                    query
+                    query,
                 )

                 futures[future] = tool

@@ -50,12 +53,15 @@ class ExternalDataFetch:
         inputs.update(results)
         return inputs

-    def _query_external_data_tool(self, flask_app: Flask,
-                                  tenant_id: str,
-                                  app_id: str,
-                                  external_data_tool: ExternalDataVariableEntity,
-                                  inputs: dict,
-                                  query: str) -> tuple[Optional[str], Optional[str]]:
+    def _query_external_data_tool(
+        self,
+        flask_app: Flask,
+        tenant_id: str,
+        app_id: str,
+        external_data_tool: ExternalDataVariableEntity,
+        inputs: dict,
+        query: str,
+    ) -> tuple[Optional[str], Optional[str]]:
         """
         Query external data tool.
         :param flask_app: flask app

@@ -72,17 +78,10 @@ class ExternalDataFetch:
             tool_config = external_data_tool.config

             external_data_tool_factory = ExternalDataToolFactory(
-                name=tool_type,
-                tenant_id=tenant_id,
-                app_id=app_id,
-                variable=tool_variable,
-                config=tool_config
+                name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config
             )

             # query external data tool
-            result = external_data_tool_factory.query(
-                inputs=inputs,
-                query=query
-            )
+            result = external_data_tool_factory.query(inputs=inputs, query=query)

             return tool_variable, result
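Note: fetch() fans one worker out per external data tool and merges each (variable, result) pair back into the inputs dict. A framework-free sketch of that shape — no Flask app context here, which the real code must push inside each worker:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def query_tool(name: str) -> tuple[str, str]:
        # Stand-in for _query_external_data_tool; returns (variable, result).
        return name, f"value-for-{name}"

    inputs: dict[str, str] = {}
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(query_tool, n) for n in ("weather", "member_level")]
        for future in as_completed(futures):
            variable, result = future.result()
            inputs[variable] = result
    print(inputs)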
@@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension


 class ExternalDataToolFactory:
-
     def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
         extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
         self.__extension_instance = extension_class(
-            tenant_id=tenant_id,
-            app_id=app_id,
-            variable=variable,
-            config=config
+            tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
         )

     @classmethod
@@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel):
     """
     File Upload Entity.
     """
+
     image_config: Optional[dict[str, Any]] = None


 class FileType(enum.Enum):
-    IMAGE = 'image'
+    IMAGE = "image"

     @staticmethod
     def value_of(value):

@@ -28,9 +29,9 @@ class FileType(enum.Enum):


 class FileTransferMethod(enum.Enum):
-    REMOTE_URL = 'remote_url'
-    LOCAL_FILE = 'local_file'
-    TOOL_FILE = 'tool_file'
+    REMOTE_URL = "remote_url"
+    LOCAL_FILE = "local_file"
+    TOOL_FILE = "tool_file"

     @staticmethod
     def value_of(value):

@@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum):
             return member
         raise ValueError(f"No matching enum found for value '{value}'")

+
 class FileBelongsTo(enum.Enum):
-    USER = 'user'
-    ASSISTANT = 'assistant'
+    USER = "user"
+    ASSISTANT = "assistant"

     @staticmethod
     def value_of(value):

@@ -65,16 +67,16 @@ class FileVar(BaseModel):

     def to_dict(self) -> dict:
         return {
-            '__variant': self.__class__.__name__,
-            'tenant_id': self.tenant_id,
-            'type': self.type.value,
-            'transfer_method': self.transfer_method.value,
-            'url': self.preview_url,
-            'remote_url': self.url,
-            'related_id': self.related_id,
-            'filename': self.filename,
-            'extension': self.extension,
-            'mime_type': self.mime_type,
+            "__variant": self.__class__.__name__,
+            "tenant_id": self.tenant_id,
+            "type": self.type.value,
+            "transfer_method": self.transfer_method.value,
+            "url": self.preview_url,
+            "remote_url": self.url,
+            "related_id": self.related_id,
+            "filename": self.filename,
+            "extension": self.extension,
+            "mime_type": self.mime_type,
         }

     def to_markdown(self) -> str:

@@ -86,7 +88,7 @@ class FileVar(BaseModel):
         if self.type == FileType.IMAGE:
             text = f'![{self.filename or ""}]({preview_url})'
         else:
-            text = f'[{self.filename or preview_url}]({preview_url})'
+            text = f"[{self.filename or preview_url}]({preview_url})"

         return text

@@ -115,28 +117,29 @@ class FileVar(BaseModel):
             return ImagePromptMessageContent(
                 data=self.data,
                 detail=ImagePromptMessageContent.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
+                if image_config.get("detail") == "high"
+                else ImagePromptMessageContent.DETAIL.LOW,
             )

     def _get_data(self, force_url: bool = False) -> Optional[str]:
         from models.model import UploadFile

         if self.type == FileType.IMAGE:
             if self.transfer_method == FileTransferMethod.REMOTE_URL:
                 return self.url
             elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
-                upload_file = (db.session.query(UploadFile)
-                               .filter(
-                    UploadFile.id == self.related_id,
-                    UploadFile.tenant_id == self.tenant_id
-                ).first())
-
-                return UploadFileParser.get_image_data(
-                    upload_file=upload_file,
-                    force_url=force_url
-                )
+                upload_file = (
+                    db.session.query(UploadFile)
+                    .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
+                    .first()
+                )
+
+                return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
             elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                 extension = self.extension
                 # add sign url
-                return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension)
+                return ToolFileParser.get_tool_file_manager().sign_file(
+                    tool_file_id=self.related_id, extension=extension
+                )

         return None
@@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS


 class MessageFileParser:
-
     def __init__(self, tenant_id: str, app_id: str) -> None:
         self.tenant_id = tenant_id
         self.app_id = app_id

-    def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
-                                         user: Union[Account, EndUser]) -> list[FileVar]:
+    def validate_and_transform_files_arg(
+        self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
+    ) -> list[FileVar]:
         """
         validate and transform files arg

@@ -30,22 +30,22 @@ class MessageFileParser:
         """
         for file in files:
             if not isinstance(file, dict):
-                raise ValueError('Invalid file format, must be dict')
-            if not file.get('type'):
-                raise ValueError('Missing file type')
-            FileType.value_of(file.get('type'))
-            if not file.get('transfer_method'):
-                raise ValueError('Missing file transfer method')
-            FileTransferMethod.value_of(file.get('transfer_method'))
-            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
-                if not file.get('url'):
-                    raise ValueError('Missing file url')
-                if not file.get('url').startswith('http'):
-                    raise ValueError('Invalid file url')
-            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
-                raise ValueError('Missing file upload_file_id')
-            if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
-                raise ValueError('Missing file tool_file_id')
+                raise ValueError("Invalid file format, must be dict")
+            if not file.get("type"):
+                raise ValueError("Missing file type")
+            FileType.value_of(file.get("type"))
+            if not file.get("transfer_method"):
+                raise ValueError("Missing file transfer method")
+            FileTransferMethod.value_of(file.get("transfer_method"))
+            if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
+                if not file.get("url"):
+                    raise ValueError("Missing file url")
+                if not file.get("url").startswith("http"):
+                    raise ValueError("Invalid file url")
+            if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
+                raise ValueError("Missing file upload_file_id")
+            if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
+                raise ValueError("Missing file tool_file_id")

         # transform files to file objs
         type_file_objs = self._to_file_objs(files, file_extra_config)

@@ -62,17 +62,17 @@ class MessageFileParser:
                 continue

             # Validate number of files
-            if len(files) > image_config['number_limits']:
+            if len(files) > image_config["number_limits"]:
                 raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

             for file_obj in file_objs:
                 # Validate transfer method
-                if file_obj.transfer_method.value not in image_config['transfer_methods']:
-                    raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
+                if file_obj.transfer_method.value not in image_config["transfer_methods"]:
+                    raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")

                 # Validate file type
                 if file_obj.type != FileType.IMAGE:
-                    raise ValueError(f'Invalid file type: {file_obj.type}')
+                    raise ValueError(f"Invalid file type: {file_obj.type}")

                 if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                     # check remote url valid and is image

@@ -81,18 +81,21 @@ class MessageFileParser:
                         raise ValueError(error)
                 elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                     # get upload file from upload_file_id
-                    upload_file = (db.session.query(UploadFile)
-                                   .filter(
-                        UploadFile.id == file_obj.related_id,
-                        UploadFile.tenant_id == self.tenant_id,
-                        UploadFile.created_by == user.id,
-                        UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-                        UploadFile.extension.in_(IMAGE_EXTENSIONS)
-                    ).first())
+                    upload_file = (
+                        db.session.query(UploadFile)
+                        .filter(
+                            UploadFile.id == file_obj.related_id,
+                            UploadFile.tenant_id == self.tenant_id,
+                            UploadFile.created_by == user.id,
+                            UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                            UploadFile.extension.in_(IMAGE_EXTENSIONS),
+                        )
+                        .first()
+                    )

                     # check upload file is belong to tenant and user
                     if not upload_file:
-                        raise ValueError('Invalid upload file')
+                        raise ValueError("Invalid upload file")

                     new_files.append(file_obj)

@@ -113,8 +116,9 @@ class MessageFileParser:
         # return all file objs
         return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

-    def _to_file_objs(self, files: list[Union[dict, MessageFile]],
-                      file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]:
+    def _to_file_objs(
+        self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
+    ) -> dict[FileType, list[FileVar]]:
         """
         transform files to file objs

@@ -152,23 +156,23 @@ class MessageFileParser:
         :return:
         """
         if isinstance(file, dict):
-            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
+            transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
             if transfer_method != FileTransferMethod.TOOL_FILE:
                 return FileVar(
                     tenant_id=self.tenant_id,
-                    type=FileType.value_of(file.get('type')),
+                    type=FileType.value_of(file.get("type")),
                     transfer_method=transfer_method,
-                    url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
-                    related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
-                    extra_config=file_extra_config
+                    url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                    related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                    extra_config=file_extra_config,
                 )
             return FileVar(
                 tenant_id=self.tenant_id,
-                type=FileType.value_of(file.get('type')),
+                type=FileType.value_of(file.get("type")),
                 transfer_method=transfer_method,
                 url=None,
-                related_id=file.get('tool_file_id'),
-                extra_config=file_extra_config
+                related_id=file.get("tool_file_id"),
+                extra_config=file_extra_config,
             )
         else:
             return FileVar(

@@ -178,7 +182,7 @@ class MessageFileParser:
                 transfer_method=FileTransferMethod.value_of(file.transfer_method),
                 url=file.url,
                 related_id=file.upload_file_id or None,
-                extra_config=file_extra_config
+                extra_config=file_extra_config,
             )

     def _check_image_remote_url(self, url):

@@ -190,17 +194,17 @@ class MessageFileParser:
         def is_s3_presigned_url(url):
             try:
                 parsed_url = urlparse(url)
-                if 'amazonaws.com' not in parsed_url.netloc:
+                if "amazonaws.com" not in parsed_url.netloc:
                     return False
                 query_params = parse_qs(parsed_url.query)
-                required_params = ['Signature', 'Expires']
+                required_params = ["Signature", "Expires"]
                 for param in required_params:
                     if param not in query_params:
                         return False
-                if not query_params['Expires'][0].isdigit():
+                if not query_params["Expires"][0].isdigit():
                     return False
-                signature = query_params['Signature'][0]
-                if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature):
+                signature = query_params["Signature"][0]
+                if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                     return False
                 return True
             except Exception:
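Note: the nested is_s3_presigned_url helper above is a heuristic for SigV2-style presigned URLs (Signature/Expires query parameters); SigV4-signed URLs use different parameter names, so this check is deliberately narrow. A standalone, runnable version of the same logic:

    from urllib.parse import parse_qs, urlparse

    def is_s3_presigned_url(url: str) -> bool:
        parsed = urlparse(url)
        if "amazonaws.com" not in parsed.netloc:
            return False
        query = parse_qs(parsed.query)
        return "Signature" in query and query.get("Expires", [""])[0].isdigit()

    print(is_s3_presigned_url("https://bucket.s3.amazonaws.com/key?Signature=abc&Expires=1700000000"))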
|
@ -1,8 +1,7 @@
|
|||
tool_file_manager = {
|
||||
'manager': None
|
||||
}
|
||||
tool_file_manager = {"manager": None}
|
||||
|
||||
|
||||
class ToolFileParser:
|
||||
@staticmethod
|
||||
def get_tool_file_manager() -> 'ToolFileManager':
|
||||
return tool_file_manager['manager']
|
||||
def get_tool_file_manager() -> "ToolFileManager":
|
||||
return tool_file_manager["manager"]
|
||||
|
|
|

@@ -9,7 +9,7 @@ from typing import Optional
 from configs import dify_config
 from extensions.ext_storage import storage
 
-IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
 IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
 
 

@@ -22,18 +22,18 @@ class UploadFileParser:
         if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
-        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url:
+        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
             return cls.get_signed_temp_image_url(upload_file.id)
         else:
             # get image file base64
             try:
                 data = storage.load(upload_file.key)
             except FileNotFoundError:
-                logging.error(f'File not found: {upload_file.key}')
+                logging.error(f"File not found: {upload_file.key}")
                 return None
 
-            encoded_string = base64.b64encode(data).decode('utf-8')
-            return f'data:{upload_file.mime_type};base64,{encoded_string}'
+            encoded_string = base64.b64encode(data).decode("utf-8")
+            return f"data:{upload_file.mime_type};base64,{encoded_string}"
 
     @classmethod
     def get_signed_temp_image_url(cls, upload_file_id) -> str:

@@ -44,7 +44,7 @@ class UploadFileParser:
         :return:
         """
         base_url = dify_config.FILES_URL
-        image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview'
+        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
 
         timestamp = str(int(time.time()))
         nonce = os.urandom(16).hex()

@@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
 
 logger = logging.getLogger(__name__)
 
+
 class CodeExecutionException(Exception):
     pass
 
+
 class CodeExecutionResponse(BaseModel):
     class Data(BaseModel):
         stdout: Optional[str] = None

@@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel):
 
 
 class CodeLanguage(str, Enum):
-    PYTHON3 = 'python3'
-    JINJA2 = 'jinja2'
-    JAVASCRIPT = 'javascript'
+    PYTHON3 = "python3"
+    JINJA2 = "jinja2"
+    JAVASCRIPT = "javascript"
 
 
 class CodeExecutor:

@@ -45,63 +47,65 @@ class CodeExecutor:
     }
 
     code_language_to_running_language = {
-        CodeLanguage.JAVASCRIPT: 'nodejs',
+        CodeLanguage.JAVASCRIPT: "nodejs",
         CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
         CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
     }
 
-    supported_dependencies_languages: set[CodeLanguage] = {
-        CodeLanguage.PYTHON3
-    }
+    supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}
 
     @classmethod
-    def execute_code(cls,
-                     language: CodeLanguage,
-                     preload: str,
-                     code: str) -> str:
+    def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
         """
         Execute code
         :param language: code language
         :param code: code
         :return:
         """
-        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run'
+        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
 
-        headers = {
-            'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
-        }
+        headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
 
         data = {
-            'language': cls.code_language_to_running_language.get(language),
-            'code': code,
-            'preload': preload,
-            'enable_network': True
+            "language": cls.code_language_to_running_language.get(language),
+            "code": code,
+            "preload": preload,
+            "enable_network": True,
         }
 
         try:
-            response = post(str(url), json=data, headers=headers,
-                            timeout=Timeout(
-                                connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
-                                read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
-                                write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
-                                pool=None))
+            response = post(
+                str(url),
+                json=data,
+                headers=headers,
+                timeout=Timeout(
+                    connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
+                    read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
+                    write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
+                    pool=None,
+                ),
+            )
             if response.status_code == 503:
-                raise CodeExecutionException('Code execution service is unavailable')
+                raise CodeExecutionException("Code execution service is unavailable")
             elif response.status_code != 200:
-                raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running')
+                raise Exception(
+                    f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
+                )
         except CodeExecutionException as e:
             raise e
         except Exception as e:
-            raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
-                                         ' please check if the sandbox service is running.'
-                                         f' ( Error: {str(e)} )')
+            raise CodeExecutionException(
+                "Failed to execute code, which is likely a network issue,"
+                " please check if the sandbox service is running."
+                f" ( Error: {str(e)} )"
+            )
 
         try:
            response = response.json()
        except:
-            raise CodeExecutionException('Failed to parse response')
+            raise CodeExecutionException("Failed to parse response")
 
-        if (code := response.get('code')) != 0:
+        if (code := response.get("code")) != 0:
             raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
 
         response = CodeExecutionResponse(**response)

@@ -109,7 +113,7 @@ class CodeExecutor:
         if response.data.error:
             raise CodeExecutionException(response.data.error)
 
-        return response.data.stdout or ''
+        return response.data.stdout or ""
 
     @classmethod
     def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:

@@ -122,7 +126,7 @@ class CodeExecutor:
         """
         template_transformer = cls.code_template_transformers.get(language)
         if not template_transformer:
-            raise CodeExecutionException(f'Unsupported language {language}')
+            raise CodeExecutionException(f"Unsupported language {language}")
 
         runner, preload = template_transformer.transform_caller(code, inputs)
 

@@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel):
         return {
             "type": "code",
             "config": {
-                "variables": [
-                    {
-                        "variable": "arg1",
-                        "value_selector": []
-                    },
-                    {
-                        "variable": "arg2",
-                        "value_selector": []
-                    }
-                ],
+                "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
                 "code_language": cls.get_language(),
                 "code": cls.get_default_code(),
-                "outputs": {
-                    "result": {
-                        "type": "string",
-                        "children": None
-                    }
-                }
-            }
-        }
+                "outputs": {"result": {"type": "string", "children": None}},
+            },
+        }

@@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider):
                 result: arg1 + arg2
             }
         }
-            """)
+            """
+        )

@@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
            var output_json = JSON.stringify(output_obj)
            var result = `<<RESULT>>${{output_json}}<<RESULT>>`
            console.log(result)
-            """)
+            """
+        )
        return runner_script

@@ -10,8 +10,6 @@ class Jinja2Formatter:
         :param inputs: inputs
         :return:
         """
-        result = CodeExecutor.execute_workflow_code_template(
-            language=CodeLanguage.JINJA2, code=template, inputs=inputs
-        )
+        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
 
-        return result['result']
+        return result["result"]

@@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         :param response: response
         :return:
         """
-        return {
-            'result': cls.extract_result_str_from_response(response)
-        }
+        return {"result": cls.extract_result_str_from_response(response)}
 
     @classmethod
     def get_runner_script(cls) -> str:

@@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider):
                 return {
                     "result": arg1 + arg2,
                 }
-            """)
+            """
+        )

Some files were not shown because too many files have changed in this diff.