chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang 2024-09-10 17:00:20 +08:00 committed by GitHub
parent 178730266d
commit 2cf1187b32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
724 changed files with 21180 additions and 21123 deletions

View File

@ -1 +1 @@
import core.moderation.base
import core.moderation.base

View File

@ -25,17 +25,19 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
trace_manager = app_generate_entity.trace_manager
# check model mode
if 'Observation' not in app_generate_entity.model_conf.stop:
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append('Observation')
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
llm_usage = {"usage": None}
final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
action_str='',
observation='',
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else {},
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
observation='',
observation="",
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=usage_dict['usage']
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ''
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(
scratchpad.action.action_input)
final_answer = json.dumps(scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f'{scratchpad.action.action_input}'
final_answer = f"{scratchpad.action.action_input}"
except json.JSONDecodeError:
final_answer = f'{scratchpad.action.action_input}'
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict['usage']
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint=''
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[]
messages_ids=[],
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
system_fingerprint="",
)
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, ToolInvokeMeta]:
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as)
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(
action_name=action['action'],
action_input=action['action_input']
)
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e:
continue
@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
format assistant message
"""
message = ''
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not current_scratchpad:
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or 'I am thinking about how to help you',
action_str='',
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except:
pass
elif isinstance(message, ToolPromptMessage):
@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory
memory=self.memory,
).get_prompt()
return historic_prompts

View File

@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt \
.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content='')
assistant_message = AssistantPromptMessage(content="")
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages

View File

@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ''
assistant_prompt = ""
for unit in agent_scratchpad:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
query_prompt = f"Question: {self._query}"
# join all messages
prompt = system_prompt \
.replace("{{historic_messages}}", historic_prompt) \
.replace("{{agent_scratchpad}}", assistant_prompt) \
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]
return [UserPromptMessage(content=prompt)]

View File

@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary.
"""
return {
'action': self.action_name,
'action_input': self.action_input,
"action": self.action_name,
"action_input": self.action_input,
}
agent_response: Optional[str] = None
@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final.
"""
return self.action is None or (
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
"""
Agent Entity.
@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
provider: str
model: str

View File

@ -24,11 +24,9 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
llm_usage = {"usage": None}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens
@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ''
response = ""
# save tool call names and inputs
tool_call_names = ''
tool_call_inputs = ''
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None
@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
increase_usage(llm_usage, result.usage)
@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content
if not result.message.content:
result.message.content = ''
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0,
message=result.message,
usage=result.usage,
)
),
)
assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls=[
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage
llm_usage=current_llm_usage,
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# call tools
tool_responses = []
@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict()
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response)
if tool_response['tool_response'] is not None:
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
content=tool_response["tool_response"],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
thought=None,
tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response['tool_call_name']: tool_response['tool_response']
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer=None,
messages_ids=message_file_ids
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool
for prompt_tool in prompt_messages_tools:
@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
system_fingerprint="",
)
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
PublishFrom.APPLICATION_MANAGER,
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
Initialize system message
"""
@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
memory=self.memory,
).get_prompt()
prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

View File

@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
@ -22,7 +23,7 @@ class CotAgentOutputParser:
action = action[0]
for key, value in action.items():
if 'input' in key.lower():
if "input" in key.lower():
action_input = value
else:
action_name = value
@ -33,37 +34,37 @@ class CotAgentOutputParser:
action_input=action_input,
)
else:
return json_str or ''
return json_str or ""
except:
return json_str or ''
return json_str or ""
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
code_block_cache = ''
code_block_cache = ""
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False
got_json = False
action_cache = ''
action_str = 'action:'
action_cache = ""
action_str = "action:"
action_idx = 0
thought_cache = ''
thought_str = 'thought:'
thought_cache = ""
thought_str = "thought:"
thought_idx = 0
for response in llm_response:
if response.delta.usage:
usage_dict['usage'] = response.delta.usage
usage_dict["usage"] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
@ -72,24 +73,24 @@ class CotAgentOutputParser:
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""
if delta == '`':
if delta == "`":
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
code_block_cache = ""
else:
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@ -97,7 +98,7 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
@ -105,18 +106,18 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ''
action_cache = ""
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@ -124,7 +125,7 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
@ -132,31 +133,31 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ''
thought_cache = ""
thought_idx = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
code_block_cache = ""
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == '{':
if delta == "{":
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
elif delta == "}":
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@ -172,12 +173,12 @@ class CotAgentOutputParser:
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
yield delta.replace("`", "")
index += steps
@ -186,4 +187,3 @@ class CotAgentOutputParser:
if json_cache:
yield parse_action(json_cache)

View File

@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = {
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
"english": {
"chat": {
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}
}

View File

@ -26,34 +26,24 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict,
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
)
additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert(config=config_dict)
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
return additional_features

View File

@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
if sensitive_word_avoidance_dict.get('enabled'):
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
)
else:
return None
@classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {
"enabled": False
}
config["sensitive_word_avoidance"] = {"enabled": False}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
return config, ["sensitive_word_avoidance"]

View File

@ -12,67 +12,70 @@ class AgentConfigManager:
:param config: model config args
"""
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode']:
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react':
elif agent_strategy == "cot" or agent_strategy == "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config['model']['provider'] == 'openai':
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool.get('tool_parameters', {})
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
"react_router",
"router",
]:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
)
return AgentEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5)
max_iteration=agent_dict.get("max_iteration", 5),
)
return None

View File

@ -15,39 +15,38 @@ class DatasetConfigManager:
:param config: model config args
"""
dataset_ids = []
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
if "datasets" in config.get("dataset_configs", {}):
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
for dataset in datasets.get('datasets', []):
for dataset in datasets.get("datasets", []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset':
if len(keys) == 0 or keys[0] != "dataset":
continue
dataset = dataset['dataset']
dataset = dataset["dataset"]
if 'enabled' not in dataset or not dataset['enabled']:
if "enabled" not in dataset or not dataset["enabled"]:
continue
dataset_id = dataset.get('id', None)
dataset_id = dataset.get("id", None)
if dataset_id:
dataset_ids.append(dataset_id)
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
if (
"agent_mode" in config
and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})
agent_dict = config.get('agent_mode', {})
for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != 'dataset':
if key != "dataset":
continue
tool_item = tool[key]
@ -55,30 +54,28 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_id = tool_item["id"]
dataset_ids.append(dataset_id)
if len(dataset_ids) == 0:
return None
# dataset configs
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
dataset_configs = {"retrieval_model": "multiple"}
query_variable = config.get("dataset_query_variable")
if dataset_configs['retrieval_model'] == 'single':
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
)
)
dataset_configs["retrieval_model"]
),
),
)
else:
return DatasetEntity(
@ -86,15 +83,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
),
)
@classmethod
@ -111,13 +108,10 @@ class DatasetConfigManager:
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {'retrieval_model': 'single'}
config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
@ -125,8 +119,9 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@ -148,10 +143,7 @@ class DatasetConfigManager:
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -188,7 +180,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@ -25,9 +23,7 @@ class ModelConfigConverter:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
@ -38,8 +34,7 @@ class ModelConfigConverter:
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_config.model
model_type=ModelType.LLM, model=model_config.model
)
if model_credentials is None:
@ -51,8 +46,7 @@ class ModelConfigConverter:
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model,
model_type=ModelType.LLM
model=model_config.model, model_type=ModelType.LLM
)
if provider_model is None:
@ -69,24 +63,18 @@ class ModelConfigConverter:
# model config
completion_params = model_config.parameters
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args
"""
# model config
model_config = config.get('model')
model_config = config.get("model")
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get('completion_params')
completion_params = model_config.get("completion_params")
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model_config.get('mode')
model_mode = model_config.get("mode")
return ModelConfigEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
mode=model_mode,
parameters=completion_params,
stop=stop,
@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id
:param config: app model config args
"""
if 'model' not in config:
if "model" not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
@ -52,17 +52,16 @@ class ModelConfigManager:
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if 'name' not in config["model"]:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
provider=config["model"]["provider"], model_type=ModelType.LLM
)
if not models:
@ -80,12 +79,12 @@ class ModelConfigManager:
# model.mode
if model_mode:
config['model']["mode"] = model_mode
config["model"]["mode"] = model_mode
else:
config['model']["mode"] = "completion"
config["model"]["mode"] = "completion"
# model.completion_params
if 'completion_params' not in config["model"]:
if "completion_params" not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params(
@ -101,7 +100,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type")
# stop
if 'stop' not in cp:
if "stop" not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")

View File

@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
"prompt": completion_prompt_config["prompt"]["text"],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
advanced_completion_prompt_template=advanced_completion_prompt_template,
)
@classmethod
@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config['prompt_type'] not in prompt_type_vals:
if config["prompt_type"] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config
@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
)
model_mode_vals = [mode.value for mode in ModelMode]
if config['model']["mode"] not in model_mode_vals:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

View File

@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
external_data_tools = config.get("external_data_tools", [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
variable=external_data_tool["variable"],
type=external_data_tool["type"],
config=external_data_tool["config"],
)
)
# variables and external_data_tools
for variables in config.get('user_input_form', []):
for variables in config.get("user_input_form", []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
if "config" not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable['variable'],
type=variable['type'],
config=variable['config']
variable=variable["variable"], type=variable["type"], config=variable["config"]
)
)
elif variable_type in [
@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
variable=variable.get("variable"),
description=variable.get("description"),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
)
)
@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
if 'label' not in form_item:
if "label" not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if 'variable' not in form_item:
if "variable" not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")
variables.append(form_item["variable"])
if 'required' not in form_item or not form_item["required"]:
if "required" not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if 'options' not in form_item or not form_item["options"]:
if "options" not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"]
@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
typ = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
return config, ["external_data_tools"]

View File

@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
mode: Optional[str] = None
@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = 'simple'
ADVANCED = 'advanced'
SIMPLE = "simple"
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> 'PromptType':
def value_of(cls, value: str) -> "PromptType":
"""
Get value of given mode.
@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt type value {value}')
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = 'single'
MULTIPLE = 'multiple'
SINGLE = "single"
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy':
def value_of(cls, value: str) -> "RetrieveStrategy":
"""
Get value of given mode.
@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid retrieve strategy value {value}')
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = .0
rerank_mode: Optional[str] = 'reranking_model'
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
"""
Tracing Config Entity.
"""
enabled: bool
tracing_provider: str
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None
opening_statement: Optional[str] = None
@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
class AppConfig(BaseModel):
"""
Application Config Entity.
"""
tenant_id: str
app_id: str
app_mode: AppMode
@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
"""
App Model Config From.
"""
ARGS = 'args'
APP_LATEST_CONFIG = 'app-latest-config'
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
ARGS = "args"
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
class EasyUIBasedAppConfig(AppConfig):
"""
Easy UI Based App Config Entity.
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict
@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str

View File

@ -13,21 +13,19 @@ class FileUploadConfigManager:
:param config: model config args
:param is_vision: if True, the feature is vision feature
"""
file_upload_dict = config.get('file_upload')
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get('image'):
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = {
'number_limits': file_upload_dict['image']['number_limits'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
}
if is_vision:
image_config['detail'] = file_upload_dict['image']['detail']
image_config["detail"] = file_upload_dict["image"]["detail"]
return FileExtraConfig(
image_config=image_config
)
return FileExtraConfig(image_config=image_config)
return None
@ -49,21 +47,21 @@ class FileUploadConfigManager:
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
detail = config["file_upload"]["image"]["detail"]
if detail not in ["high", "low"]:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods']
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ['remote_url', 'local_file']:
if method not in ["remote_url", "local_file"]:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]

View File

@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get('more_like_this')
more_like_this_dict = config.get("more_like_this")
if more_like_this_dict:
if more_like_this_dict.get('enabled'):
if more_like_this_dict.get("enabled"):
more_like_this = True
return more_like_this
@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {
"enabled": False
}
config["more_like_this"] = {"enabled": False}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")

View File

@ -1,5 +1,3 @@
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
:param config: model config args
"""
# opening statement
opening_statement = config.get('opening_statement')
opening_statement = config.get("opening_statement")
# suggested questions
suggested_questions_list = config.get('suggested_questions')
suggested_questions_list = config.get("suggested_questions")
return opening_statement, suggested_questions_list

View File

@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource')
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
if retriever_resource_dict.get('enabled'):
if retriever_resource_dict.get("enabled"):
show_retrieve_source = True
return show_retrieve_source
@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
:param config: app model config args
"""
if not config.get("retriever_resource"):
config["retriever_resource"] = {
"enabled": False
}
config["retriever_resource"] = {"enabled": False}
if not isinstance(config["retriever_resource"], dict):
raise ValueError("retriever_resource must be of dict type")

View File

@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
:param config: model config args
"""
speech_to_text = False
speech_to_text_dict = config.get('speech_to_text')
speech_to_text_dict = config.get("speech_to_text")
if speech_to_text_dict:
if speech_to_text_dict.get('enabled'):
if speech_to_text_dict.get("enabled"):
speech_to_text = True
return speech_to_text
@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
:param config: app model config args
"""
if not config.get("speech_to_text"):
config["speech_to_text"] = {
"enabled": False
}
config["speech_to_text"] = {"enabled": False}
if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type")

View File

@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: model config args
"""
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
if suggested_questions_after_answer_dict:
if suggested_questions_after_answer_dict.get('enabled'):
if suggested_questions_after_answer_dict.get("enabled"):
suggested_questions_after_answer = True
return suggested_questions_after_answer
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: app model config args
"""
if not config.get("suggested_questions_after_answer"):
config["suggested_questions_after_answer"] = {
"enabled": False
}
config["suggested_questions_after_answer"] = {"enabled": False}
if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not \
config["suggested_questions_after_answer"]["enabled"]:
if (
"enabled" not in config["suggested_questions_after_answer"]
or not config["suggested_questions_after_answer"]["enabled"]
):
config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

View File

@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
:param config: model config args
"""
text_to_speech = None
text_to_speech_dict = config.get('text_to_speech')
text_to_speech_dict = config.get("text_to_speech")
if text_to_speech_dict:
if text_to_speech_dict.get('enabled'):
if text_to_speech_dict.get("enabled"):
text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
enabled=text_to_speech_dict.get("enabled"),
voice=text_to_speech_dict.get("voice"),
language=text_to_speech_dict.get("language"),
)
return text_to_speech
@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
:param config: app model config args
"""
if not config.get("text_to_speech"):
config["text_to_speech"] = {
"enabled": False,
"voice": "",
"language": ""
}
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
if not isinstance(config["text_to_speech"], dict):
raise ValueError("text_to_speech must be of dict type")

View File

@ -1,4 +1,3 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
stream=stream,
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
conversation_id=None,
inputs={},
query='',
query="",
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False
},
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
stream=stream
stream=stream,
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
def _generate(
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
is_first_conversation = True
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": contextvars.copy_context(),
},
)
worker_thread.start()
@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
)
runner.run()
@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
audio_queue.put(AudioTrunk("finish", b""))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
values = [voice.get("value") for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get('output', '')
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
self.msg_text = ""
except Exception as e:
self.logger.warning(e)

View File

@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
:param application_generate_entity: application generate entity
@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
)
except ModerationException as e:
self._complete_with_stream_output(
text=str(e),
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
if annotation_reply:
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
self._complete_with_stream_output(
text=annotation_reply.content,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)
self._publish_event(QueueStopEvent(stopped_by=stopped_by))

View File

@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
# Save message
@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
extras=self._application_generate_entity.extras,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy()
extras["metadata"] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:

View File

@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model
@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
agent=AgentConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
agent=AgentConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# dataset configs
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args
"""
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [member.value for member in
list(PlanningStrategy.__members__.values())]:
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "dataset":
if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not stream:
raise ValueError('Agent Chat App does not support blocking mode')
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
override_config_dict=override_model_config_dict,
)
# get tracing instance
@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AgentChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self, flask_app: Flask,
self,
flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
"""
def run(
self, application_generate_entity: AgentChatAppGenerateEntity,
self,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
memory = None
@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# moderation
@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# reorganize all inputs and template to prompt messages
@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent
# load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id)
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
# init model instance
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,
@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
model_instance=model_instance,
)
invoke_result = runner.run(
@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True
agent=True,
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id
).first()
tool_variables: ToolConversationVariables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str='[]',
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(**{
'conversation_id': db_variables.conversation_id,
'user_id': db_variables.user_id,
'tenant_id': db_variables.tenant_id,
'pool': db_variables.variables
})
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
message: Message) -> LLMUsage:
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0
all_answer_tokens = 0
@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)

View File

@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@classmethod
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f'data: {chunk}\n\n'
yield f"data: {chunk}\n\n"
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f'data: {chunk}\n\n'
yield f"data: {chunk}\n\n"
return _generate_simple_response()
@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
# show_retrieve_source
if 'retriever_resources' in metadata:
metadata['retriever_resources'] = []
for resource in metadata['retriever_resources']:
metadata['retriever_resources'].append({
'segment_id': resource['segment_id'],
'position': resource['position'],
'document_name': resource['document_name'],
'score': resource['score'],
'content': resource['content'],
})
if "retriever_resources" in metadata:
metadata["retriever_resources"] = []
for resource in metadata["retriever_resources"]:
metadata["retriever_resources"].append(
{
"segment_id": resource["segment_id"],
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
}
)
# show annotation reply
if 'annotation_reply' in metadata:
del metadata['annotation_reply']
if "annotation_reply" in metadata:
del metadata["annotation_reply"]
# show usage
if 'usage' in metadata:
del metadata['usage']
if "usage" in metadata:
del metadata["usage"]
return metadata
@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
error_responses = {
ValueError: {'code': 'invalid_param', 'status': 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
"code": "provider_quota_exceeded",
"message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"status": 400,
},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
InvokeError: {'code': 'completion_request_error', 'status': 400}
ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {"code": "completion_request_error", "status": 400},
}
# Determine the response based on the type of exception
@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
data = v
if data:
data.setdefault('message', getattr(e, 'description', str(e)))
data.setdefault("message", getattr(e, "description", str(e)))
else:
logging.error(e)
data = {
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
}
return data

View File

@ -16,10 +16,10 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.variable} is required in input form')
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
return var.default or ""
if (
var.type
in (
@ -34,7 +34,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
if "." in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
@ -43,14 +43,14 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
return user_input_value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace('\x00', '')
return value.replace("\x00", "")
return value

View File

@ -24,9 +24,7 @@ class PublishFrom(Enum):
class AppQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
if not user_id:
raise ValueError("user is required")
@ -34,9 +32,10 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
f"{user_prefix}-{self._user_id}")
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue()
@ -66,8 +65,7 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
)
if elapsed_time // 10 > last_ping_time:
@ -88,9 +86,7 @@ class AppQueueManager:
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(
error=e
), pub_from)
self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
@ -122,8 +118,8 @@ class AppQueueManager:
if result is None:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@ -168,9 +164,11 @@ class AppQueueManager:
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)
class GenerateTaskStoppedException(Exception):

View File

@ -31,12 +31,15 @@ if TYPE_CHECKING:
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None) -> int:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
@ -49,18 +52,20 @@ class AppRunner:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None:
return -1
@ -75,36 +80,39 @@ class AppRunner:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query
query=query,
)
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None:
return -1
@ -112,27 +120,28 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def organize_prompt_messages(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:
@ -152,60 +161,54 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
else:
memory_config = MemoryConfig(
window=MemoryConfig.WindowConfig(
enabled=False
)
)
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate(
text=advanced_completion_prompt_template.prompt
)
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(
text=message.text,
role=message.role
))
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(self, queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
def direct_output(
self,
queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
"""
Direct output
:param queue_manager: application queue manager
@ -222,17 +225,10 @@ class AppRunner:
chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
@ -242,15 +238,19 @@ class AppRunner:
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
usage=usage if usage else LLMUsage.empty_usage(),
),
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False) -> None:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@ -260,21 +260,13 @@ class AppRunner:
:return:
"""
if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result direct
:param invoke_result: invoke result
@ -285,12 +277,13 @@ class AppRunner:
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_stream(
self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@ -300,21 +293,13 @@ class AppRunner:
"""
model = None
prompt_messages = []
text = ''
text = ""
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
text += result.delta.message.content
@ -331,25 +316,24 @@ class AppRunner:
usage = LLMUsage.empty_usage()
llm_result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def moderation_for_inputs(
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
@ -367,14 +351,17 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query if query else '',
query=query if query else "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
trace_manager=app_generate_entity.trace_manager,
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
def check_hosting_moderation(
self,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
@ -384,8 +371,7 @@ class AppRunner:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
)
if moderation_result:
@ -393,18 +379,20 @@ class AppRunner:
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
)
return moderation_result
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
def fill_in_inputs_from_external_data_tools(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
"""
Fill in variable inputs from external data tools if exists.
@ -417,18 +405,12 @@ class AppRunner:
"""
external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
@ -440,9 +422,5 @@ class AppRunner:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)

View File

@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
"""
Chatbot App Config Entity.
"""
pass
class ChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
:param app_model: app model
@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception('override_config_dict is required when config_from is ARGS')
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# opening_statement
@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))

View File

@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
override_config_dict=override_model_config_dict,
)
# get tracing instance
@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return ChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner
"""
def run(self, application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
memory = None
@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# moderation
@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# get context from datasets
@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files,
query=query,
context=context,
memory=memory
memory=memory,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
db.session.close()
@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
"""
Completion App Config Entity.
"""
pass
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
:param app_model: app model
@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {}
@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# get tracing instance
@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start()
@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def generate_more_like_this(self, app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate_more_like_this(
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message:
raise MessageNotExistsError()
@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model']
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
override_model_config_dict['model'] = model_dict
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
message.files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# init application generate entity
@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={}
extras={},
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start()
@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner
"""
def run(self, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message) -> None:
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
# organize all inputs and template to prompt messages
@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
# moderation
@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# get context from datasets
@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)
dataset_config = app_config.dataset
@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
message_id=message.id
message_id=message.id,
)
# reorganize all inputs and template to prompt messages
@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
context=context
context=context,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
db.session.close()
@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Handle response.
@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
try:
@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception(e)
raise e
def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation:
def _get_conversation_by_user(
self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
) -> Conversation:
conversation_filter = [
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == 'normal'
Conversation.status == "normal",
]
if isinstance(user, Account):
@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
if conversation.status != "normal":
raise ConversationCompletedError()
return conversation
def _get_app_model_config(self, app_model: App,
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return app_model_config
def _init_generate_records(self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
conversation: Optional[Conversation] = None) \
-> tuple[Conversation, Message]:
def _init_generate_records(
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
from_source = "console"
account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
override_model_configs = app_config.app_model_config_dict
# get conversation introduction
@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name='New conversation',
name="New conversation",
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id
from_account_id=account_id,
)
db.session.add(message)
@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
belongs_to='user',
belongs_to="user",
url=file.url,
upload_file_id=file.related_id,
created_by_role=('account' if account_id else 'end_user'),
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError()
@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = (
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
message = db.session.query(Message).filter(Message.id == message_id).first()
return message

View File

@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueAdvancedChatMessageEndEvent):
if isinstance(
event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()

View File

@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
"""
Workflow App Config Entity.
"""
pass
@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)

View File

@ -34,26 +34,28 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
) -> dict: ...
def generate(
@ -65,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
):
"""
Generate App response.
@ -79,26 +81,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
inputs = args["inputs"]
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@ -114,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -125,18 +120,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
workflow_thread_pool_id=workflow_thread_pool_id,
)
def _generate(
self, *,
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@ -154,17 +150,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode
app_mode=app_model.mode,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
@ -177,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
)
return WorkflowAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -199,16 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@ -219,13 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False
},
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -235,14 +222,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream
stream=stream,
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -259,7 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
workflow_thread_pool_id=workflow_thread_pool_id,
)
runner.run()
@ -267,14 +257,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
@ -283,14 +272,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
finally:
db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False) -> Union[
WorkflowAppBlockingResponse,
Generator[WorkflowAppStreamResponse, None, None]
]:
def _handle_response(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -306,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream
stream=stream,
)
try:

View File

@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
app_mode: str) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():

View File

@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
@ -120,12 +119,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
thread_pool_id=self.workflow_thread_pool_id,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id,
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id,
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):

View File

@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
SystemVariableKey.USER_ID: user_id,
}
self._task_state = WorkflowTaskState()
@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user)
db.session.close()
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> WorkflowAppBlockingResponse:
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
)
finished_at=int(stream_response.data.finished_at),
),
)
return response
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[WorkflowAppStreamResponse, None, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
"""
To stream response.
:return:
@ -158,10 +160,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
@ -171,17 +170,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None,
trace_manager=trace_manager,
)
@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
@ -394,7 +383,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
@ -417,7 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log)
@ -431,8 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text)
task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
)
return response

View File

@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# init graph
graph = Graph.init(
graph_config=graph_config
)
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError('graph not found in workflow')
raise ValueError("graph not found in workflow")
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
]
graph_config['nodes'] = node_configs
graph_config["nodes"] = node_configs
node_ids = [node.get('id') for node in node_configs]
node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config['edges'] = edge_configs
graph_config["edges"] = edge_configs
# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError('graph not found in workflow')
raise ValueError("graph not found in workflow")
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
if node.get("id") == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')
raise ValueError("iteration node id not found in workflow graph")
# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
graph_config=workflow.graph_dict, config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
)
return graph, variable_pool
@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(
QueueWorkflowFailedEvent(error=event.error)
)
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
in_iteration_id=event.in_iteration_id
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
metadata=event.metadata,
)
)
elif isinstance(event, IterationRunNextEvent):
@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@ -30,169 +30,145 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_event(
self,
event: GraphEngineEvent
) -> None:
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink')
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green')
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
)
self.on_workflow_node_execute_started(event=event)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
event=event
)
self.on_workflow_node_execute_succeeded(event=event)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
event=event
)
self.on_workflow_node_execute_failed(event=event)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
event=event
)
self.on_node_text_chunk(event=event)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
self.on_workflow_parallel_started(event=event)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
self.on_workflow_parallel_completed(event=event)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
event=event
)
self.on_workflow_iteration_started(event=event)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
event=event
)
self.on_workflow_iteration_next(event=event)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
event=event
)
self.on_workflow_iteration_completed(event=event)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started(
self,
event: NodeRunStartedEvent
) -> None:
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(
self,
event: NodeRunSucceededEvent
) -> None:
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='green')
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green')
color="green",
)
def on_workflow_node_execute_failed(
self,
event: NodeRunFailedEvent
) -> None:
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='red')
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
)
def on_node_text_chunk(
self,
event: NodeRunStreamChunkEvent
) -> None:
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
self,
event: ParallelBranchRunStartedEvent
) -> None:
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
def on_workflow_parallel_completed(
self,
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue'
color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red'
color = "red"
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
@ -201,43 +177,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
event: IterationRunStartedEvent
) -> None:
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(
self,
event: IterationRunNextEvent
) -> None:
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
self.print_text(
"\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def print_text(
self, text: str, color: Optional[str] = None, end: str = "\n"
) -> None:
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f'{text_to_print}', end=end)
print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""

View File

@ -15,13 +15,14 @@ class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = 'service-api'
WEB_APP = 'web-app'
EXPLORE = 'explore'
DEBUGGER = 'debugger'
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> 'InvokeFrom':
def value_of(cls, value: str) -> "InvokeFrom":
"""
Get value of given mode.
@ -31,7 +32,7 @@ class InvokeFrom(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str:
"""
@ -40,21 +41,22 @@ class InvokeFrom(Enum):
:return: source
"""
if self == InvokeFrom.WEB_APP:
return 'web_app'
return "web_app"
elif self == InvokeFrom.DEBUGGER:
return 'dev'
return "dev"
elif self == InvokeFrom.EXPLORE:
return 'explore_app'
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return 'api'
return "api"
return 'dev'
return "dev"
class ModelConfigWithCredentialsEntity(BaseModel):
"""
Model Config With Credentials Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
"""
App Generate Entity.
"""
task_id: str
# app config
@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
Chat Application Generate Entity.
"""
# app config
app_config: EasyUIBasedAppConfig
model_conf: ModelConfigWithCredentialsEntity
@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Completion Application Generate Entity.
"""
pass
@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Agent Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Advanced Chat Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
Workflow Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict

View File

@ -14,6 +14,7 @@ class QueueEvent(str, Enum):
"""
QueueEvent enum
"""
LLM_CHUNK = "llm_chunk"
TEXT_CHUNK = "text_chunk"
AGENT_MESSAGE = "agent_message"
@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel):
"""
QueueEvent abstract entity
"""
event: QueueEvent
@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent):
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str
node_id: str
@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent):
predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
output: Optional[Any] = None # output for the current iteration
@field_validator('output', mode='before')
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent):
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError('output must be a valid type')
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str
@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
"""
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
in_iteration_id: Optional[str] = None
@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
"""
QueueAnnotationReplyEvent entity
"""
event: QueueEvent = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
"""
QueueAdvancedChatMessageEndEvent entity
"""
event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
QueueNodeStartedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str
@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
QueueNodeSucceededEvent entity
"""
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str
@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_FILE
message_file_id: str
@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event: QueueEvent = QueueEvent.ERROR
error: Any = None
@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event: QueueEvent = QueueEvent.PING
@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
"""
QueueStopEvent entity
"""
class StopBy(Enum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent):
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
"""
task_id: str
app_mode: str
event: AppQueueEvent
@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
"""
MessageQueueMessage entity
"""
message_id: str
conversation_id: str
@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage):
"""
WorkflowQueueMessage entity
"""
pass
@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str

View File

@ -12,6 +12,7 @@ class TaskState(BaseModel):
"""
TaskState entity
"""
metadata: dict = {}
@ -19,6 +20,7 @@ class EasyUITaskState(TaskState):
"""
EasyUITaskState entity
"""
llm_result: LLMResult
@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState):
"""
WorkflowTaskState entity
"""
answer: str = ""
@ -33,6 +36,7 @@ class StreamEvent(Enum):
"""
Stream event
"""
PING = "ping"
ERROR = "error"
MESSAGE = "message"
@ -60,6 +64,7 @@ class StreamResponse(BaseModel):
"""
StreamResponse entity
"""
event: StreamEvent
task_id: str
@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
"""
ErrorStreamResponse entity
"""
event: StreamEvent = StreamEvent.ERROR
err: Exception
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str
@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str
@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
"""
MessageEndStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = {}
@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
"""
MessageFileStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_FILE
id: str
type: str
@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
"""
MessageReplaceStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str
@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
"""
AgentThoughtStreamResponse entity
"""
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
"""
AgentMessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.AGENT_MESSAGE
id: str
answer: str
@ -162,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
workflow_id: str
sequence_number: int
@ -182,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
workflow_id: str
sequence_number: int
@ -210,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@ -249,7 +266,7 @@ class NodeStartStreamResponse(StreamResponse):
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
},
}
@ -262,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@ -315,9 +333,9 @@ class NodeFinishStreamResponse(StreamResponse):
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
},
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
@ -328,6 +346,7 @@ class ParallelBranchStartStreamResponse(StreamResponse):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
@ -349,6 +368,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
@ -372,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@ -397,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@ -422,6 +444,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@ -454,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
"""
Data entity
"""
text: str
event: StreamEvent = StreamEvent.TEXT_CHUNK
@ -469,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
"""
Data entity
"""
text: str
event: StreamEvent = StreamEvent.TEXT_REPLACE
@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse):
"""
PingStreamResponse entity
"""
event: StreamEvent = StreamEvent.PING
@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel):
"""
AppStreamResponse entity
"""
stream_response: StreamResponse
@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
"""
ChatbotAppStreamResponse entity
"""
conversation_id: str
message_id: str
created_at: int
@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
"""
CompletionAppStreamResponse entity
"""
message_id: str
created_at: int
@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
"""
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel):
"""
AppBlockingResponse entity
"""
task_id: str
def to_dict(self) -> dict:
@ -532,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
@ -552,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
mode: str
message_id: str
@ -571,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
workflow_id: str
status: str

View File

@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
@ -27,8 +25,9 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_record.id).first()
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
)
if not annotation_setting:
return None
@ -41,55 +40,50 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
embedding_provider_name, embedding_model_name, "annotation"
)
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique='high_quality',
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
documents = vector.search_by_vector(
query=query,
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
from_source = 'api'
from_source = "api"
else:
from_source = 'console'
from_source = "console"
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score)
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score,
)
return annotation
except Exception as e:
logger.warning(f'Query annotation failed, exception: {str(e)}.')
logger.warning(f"Query annotation failed, exception: {str(e)}.")
return None
return None

View File

@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)
class HostingModerationFeature:
def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list[PromptMessage]) -> bool:
def check(
self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
@ -23,9 +24,6 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(
model_config,
text
)
moderation_result = moderation.check_moderation(model_config, text)
return moderation_result

View File

@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@ -27,13 +27,13 @@ class RateLimit:
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'):
if hasattr(self, "initialized"):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf')
self.last_recalculate_time = float("-inf")
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
@ -46,7 +46,7 @@ class RateLimit:
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list)
@ -54,8 +54,11 @@ class RateLimit:
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
timeout_requests = [
k
for k, v in request_details.items()
if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
@ -69,8 +72,10 @@ class RateLimit:
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
raise AppInvokeQuotaExceededError(
"Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests)
)
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
@ -116,5 +121,5 @@ class RateLimitGenerator:
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'):
if self.generator is not None and hasattr(self.generator, "close"):
self.generator.close()

View File

@ -25,25 +25,25 @@ from .variables import (
)
__all__ = [
'IntegerVariable',
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
'SegmentType',
'SegmentGroup',
'Segment',
'NoneSegment',
'NoneVariable',
'IntegerSegment',
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArraySegment',
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
]

View File

@ -28,12 +28,12 @@ from .variables import (
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if not mapping.get("name"):
raise VariableError("missing name")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result
@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
raise ValueError(f'not supported value {value}')
raise ValueError(f"not supported value {value}")

View File

@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if '.' in part and (value := variable_pool.get(part.split('.'))):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))

View File

@ -8,15 +8,15 @@ class SegmentGroup(Segment):
@property
def text(self):
return ''.join([segment.text for segment in self.value])
return "".join([segment.text for segment in self.value])
@property
def log(self):
return ''.join([segment.log for segment in self.value])
return "".join([segment.log for segment in self.value])
@property
def markdown(self):
return ''.join([segment.markdown for segment in self.value])
return "".join([segment.markdown for segment in self.value])
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@ -14,13 +14,13 @@ class Segment(BaseModel):
value_type: SegmentType
value: Any
@field_validator('value_type')
@field_validator("value_type")
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised.
"""
if value != cls.model_fields['value_type'].default:
if value != cls.model_fields["value_type"].default:
raise ValueError("Cannot modify 'value_type'")
return value
@ -50,15 +50,15 @@ class NoneSegment(Segment):
@property
def text(self) -> str:
return 'null'
return "null"
@property
def log(self) -> str:
return 'null'
return "null"
@property
def markdown(self) -> str:
return 'null'
return "null"
class StringSegment(Segment):
@ -76,24 +76,21 @@ class IntegerSegment(Segment):
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any]
@property
def text(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
return json.dumps(self.model_dump()["value"], ensure_ascii=False)
@property
def log(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
class ArraySegment(Segment):
@ -101,11 +98,11 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, 'to_markdown'):
if hasattr(item, "to_markdown"):
items.append(item.to_markdown())
else:
items.append(str(item))
return '\n'.join(items)
return "\n".join(items)
class ArrayAnySegment(ArraySegment):
@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]

View File

@ -2,14 +2,14 @@ from enum import Enum
class SegmentType(str, Enum):
NONE = 'none'
NUMBER = 'number'
STRING = 'string'
SECRET = 'secret'
ARRAY_ANY = 'array[any]'
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
OBJECT = 'object'
NONE = "none"
NUMBER = "number"
STRING = "string"
SECRET = "secret"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
GROUP = 'group'
GROUP = "group"

View File

@ -23,11 +23,11 @@ class Variable(Segment):
"""
id: str = Field(
default='',
default="",
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
description: str = Field(default='', description='Description of the variable.')
description: str = Field(default="", description="Description of the variable.")
class StringVariable(StringSegment, Variable):
@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

View File

@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(self, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool) -> None:
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
e = event.error
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError('Incorrect API key provided')
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
err = e
else:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = 'error'
refetch_message.status = "error"
refetch_message.error = err_desc
db.session.commit()
@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
:return:
"""
if isinstance(e, QuotaExceededError):
return ("Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.")
return (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
message = getattr(e, 'description', str(e))
message = getattr(e, "description", str(e))
if not message:
message = 'Internal Server Error, please contact support.'
message = "Internal Server Error, please contact support."
return message
@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
:param e: exception
:return:
"""
return ErrorStreamResponse(
task_id=self._application_generate_entity.task_id,
err=e
)
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
"""
@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
return OutputModeration(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
queue_manager=self._queue_manager
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion(
completion=completion,
public_event=False
completion=completion, public_event=False
)
self._output_moderation_handler = None

View File

@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
]
def __init__(self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool) -> None:
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model=self._model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage()
usage=LLMUsage.empty_usage(),
)
)
self._conversation_name_generate_thread = None
def process(
self,
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Process generate task pipeline.
@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse
]:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
"""
Process blocking response.
:return:
@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {
'usage': jsonable_encoder(self._task_state.llm_result.usage)
}
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
return response
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
"""
To stream response.
:return:
@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield CompletionAppStreamResponse(
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
else:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio,
task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
"""
Save message.
:return:
@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode,
self._task_state.llm_result.prompt_messages
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
if llm_result.message.content else ''
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
if llm_result.message.content
else ""
)
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
self._message.currency = usage.currency
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE,
conversation_id=self._conversation.id,
message_id=self._message.id
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
)
)
@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
def _handle_stop(self, event: QueueStopEvent) -> None:
@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model = model_config.model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_instance.get_llm_num_tokens(
self._task_state.llm_result.prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_instance.get_llm_num_tokens(
[self._task_state.llm_result.message]
)
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
credentials = model_config.credentials
@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
model, credentials, prompt_tokens, completion_tokens
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
return AgentMessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
db.session.close()
@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
tool=agent_thought.tool,
tool_labels=agent_thought.tool_labels,
tool_input=agent_thought.tool_input,
message_files=agent_thought.files
message_files=agent_thought.files,
)
return None
@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
),
)
), PublishFrom.TASK_PIPELINE
),
PublishFrom.TASK_PIPELINE,
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:

View File

@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
@ -49,15 +46,18 @@ class MessageCycleManage:
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'conversation_id': conversation.id,
'query': query
})
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation.id,
"query": query,
},
)
thread.start()
@ -65,17 +65,10 @@ class MessageCycleManage:
return None
def _generate_conversation_name_worker(self,
flask_app: Flask,
conversation_id: str,
query: str):
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
return
@ -105,12 +98,9 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
return annotation
@ -124,7 +114,7 @@ class MessageCycleManage:
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
self._task_state.metadata["retriever_resources"] = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
@ -132,27 +122,23 @@ class MessageCycleManage:
:param event: event
:return:
"""
message_file = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()
)
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file_id = tool_file_id.split(".")[0]
# get extension
if '.' in message_file.url:
if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10:
extension = '.bin'
extension = ".bin"
else:
extension = '.bin'
extension = ".bin"
# add sign url to local file
if message_file.url.startswith('http'):
if message_file.url.startswith("http"):
url = message_file.url
else:
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
@ -161,8 +147,8 @@ class MessageCycleManage:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or 'user',
url=url
belongs_to=message_file.belongs_to or "user",
url=url,
)
return None
@ -174,11 +160,7 @@ class MessageCycleManage:
:param message_id: message id
:return:
"""
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
)
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
"""
@ -186,7 +168,4 @@ class MessageCycleManage:
:param answer: answer
:return:
"""
return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
answer=answer
)
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)

View File

@ -70,14 +70,14 @@ class WorkflowCycleManage:
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == 'conversation':
if key.value == "conversation":
continue
inputs[f'sys.{key.value}'] = value
inputs[f"sys.{key.value}"] = value
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from= (
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
@ -185,20 +185,26 @@ class WorkflowCycleManage:
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
running_workflow_node_executions = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
)
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
@ -216,7 +222,9 @@ class WorkflowCycleManage:
return workflow_run
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
def _handle_node_execution_start(
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
@ -333,16 +341,16 @@ class WorkflowCycleManage:
created_by_account = workflow_run.created_by_account
if created_by_account:
created_by = {
'id': created_by_account.id,
'name': created_by_account.name,
'email': created_by_account.email,
"id": created_by_account.id,
"name": created_by_account.name,
"email": created_by_account.email,
}
else:
created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user:
created_by = {
'id': created_by_end_user.id,
'user': created_by_end_user.session_id,
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse(
@ -401,7 +409,7 @@ class WorkflowCycleManage:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras['icon'] = ToolManager.get_tool_icon(
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
@ -410,10 +418,10 @@ class WorkflowCycleManage:
return response
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
@ -424,7 +432,7 @@ class WorkflowCycleManage:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -452,13 +460,10 @@ class WorkflowCycleManage:
iteration_id=event.in_iteration_id,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
@ -476,15 +481,15 @@ class WorkflowCycleManage:
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
)
),
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
@ -501,18 +506,15 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
)
),
)
def _workflow_iteration_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
@ -534,10 +536,12 @@ class WorkflowCycleManage:
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
@ -559,10 +563,12 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
@ -585,13 +591,13 @@ class WorkflowCycleManage:
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
@ -643,7 +649,7 @@ class WorkflowCycleManage:
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
@ -656,11 +662,10 @@ class WorkflowCycleManage:
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == workflow_run_id).first()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}')
raise Exception(f"Workflow run not found: {workflow_run_id}")
return workflow_run
@ -683,6 +688,6 @@ class WorkflowCycleManage:
)
if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}')
raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution
return workflow_node_execution

View File

@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
"red": "31;1",
}
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
color: Optional[str] = ""
current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None:
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or 'green'
self.color = color or "green"
self.current_loop = 1
def on_tool_start(
@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color=self.color)
@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
)
)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(
self, thought: str
) -> None:
def on_agent_start(self, thought: str) -> None:
"""Run on agent start."""
if thought:
print_text("\n[on_agent_start] \nCurrent Loop: " + \
str(self.current_loop) + \
"\nThought: " + thought + "\n", color=self.color)
print_text(
"\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
color=self.color,
)
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(
self, color: Optional[str] = None, **kwargs: Any
) -> None:
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
"""Run on agent end."""
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

View File

@ -1,4 +1,3 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, queue_manager: AppQueueManager,
app_id: str,
message_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source="app",
source_app_id=self._app_id,
created_by_role=('account'
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self._user_id
created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
),
created_by=self._user_id,
)
db.session.add(dataset_query)
@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self._user_id
position=item.get("position"),
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource),
PublishFrom.APPLICATION_MANAGER
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

View File

@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""
"""Callback Handler that prints to std out."""

View File

@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider).first()
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
text_embeddings[i] = embedding.get_embedding()
else:
@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model,
self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_texts), max_chunks):
batch_texts = embedding_queue_texts[i:i + max_chunks]
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts,
user=self._user
)
embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
for vector in embedding_result.embeddings:
try:
@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception as e:
logging.exception('Failed transform embedding: ', e)
logging.exception("Failed transform embedding: ", e)
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider,
)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error('Failed to embed documents: ', ex)
logger.error("Failed to embed documents: ", ex)
raise ex
return text_embeddings
@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text],
user=self._user
)
embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to redis')
logging.exception("Failed to add embedding to redis")
return embedding_results

View File

@ -2,7 +2,7 @@ from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
IMAGE = "image"
@staticmethod
def value_of(value):
@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel):
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
LOW = "low"
HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW

View File

@ -12,6 +12,7 @@ class ModelStatus(Enum):
"""
Enum class for model status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel):
"""
Simple provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types
supported_model_types=provider_entity.supported_model_types,
)
@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
load_balancing_enabled: bool = False
@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel):
"""
Default model provider entity.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: DefaultModelProviderEntity

View File

@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
tenant_id: str
provider: ProviderEntity
preferred_provider_type: ProviderType
@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any(len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations)
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
if (
any(
len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations
)
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
if self.model_settings:
# check if model is disabled by admin
for model_setting in self.model_settings:
if (model_setting.model_type == model_type
and model_setting.model == model):
if model_setting.model_type == model_type and model_setting.model == model:
if not model_setting.enabled:
raise ValueError(f'Model {model} is disabled.')
raise ValueError(f"Model {model} is disabled.")
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
if (restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name):
copy_credentials['base_model_name'] = restrict_model.base_model_name
if (
restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name
):
copy_credentials["base_model_name"] = restrict_model.base_model_name
return copy_credentials
else:
@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
None
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
SystemConfigurationStatus.QUOTA_EXCEEDED
return (
SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
else SystemConfigurationStatus.QUOTA_EXCEEDED
)
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0)
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
if self.provider.provider_credential_schema
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
if self.provider.provider_credential_schema
else []
)
if provider_record:
@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
# fix origin data
if provider_record.encrypted_config:
if not provider_record.encrypted_config.startswith("{"):
original_credentials = {
"openai_api_key": provider_record.encrypted_config
}
original_credentials = {"openai_api_key": provider_record.encrypted_config}
else:
original_credentials = json.loads(provider_record.encrypted_config)
else:
@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider,
credentials=credentials
provider=self.provider.provider, credentials=credentials
)
for key, value in credentials.items():
@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True
is_valid=True,
)
db.session.add(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
# delete provider
if provider_record:
@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
def get_custom_model_credentials(
self, model_type: ModelType, model: str, obfuscated: bool = False
) -> Optional[dict]:
"""
Get custom model credentials.
@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
if self.provider.model_credential_schema
else [],
)
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> tuple[ProviderModel, dict]:
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
if self.provider.model_credential_schema
else []
)
if provider_model_record:
try:
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
original_credentials = (
json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
)
except JSONDecodeError:
original_credentials = {}
@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in credentials.items():
@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True
is_valid=True,
)
db.session.add(provider_model_record)
db.session.commit()
@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
# delete provider model
if provider_model_record:
@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.enabled = True
@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True
enabled=True,
)
db.session.add(model_setting)
db.session.commit()
@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.enabled = False
@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False
enabled=False,
)
db.session.add(model_setting)
db.session.commit()
@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
return db.session.query(ProviderModelSetting) \
return (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).count()
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.count()
)
if load_balancing_config_count <= 1:
raise ValueError('Model load balancing configuration must be more than 1.')
raise ValueError("Model load balancing configuration must be more than 1.")
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.load_balancing_enabled = True
@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True
load_balancing_enabled=True,
)
db.session.add(model_setting)
db.session.commit()
@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.load_balancing_enabled = False
@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False
load_balancing_enabled=False,
)
db.session.add(model_setting)
db.session.commit()
@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
return
# get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider
).first()
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider,
)
.first()
)
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value
preferred_provider_type=provider_type.value,
)
db.session.add(preferred_model_provider)
@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
credential_form_schemas
)
credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
# Obfuscate provider credentials
copy_credentials = credentials.copy()
@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
def get_provider_model(
self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance,
model_setting_map=model_setting_map
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance,
model_setting_map=model_setting_map
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
)
if only_active:
@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
def _get_system_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status
status=status,
)
)
@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
ConfigurateMethod.CUSTOMIZABLE_MODEL
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
continue
status = ModelStatus.ACTIVE
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status
status=status,
)
)
@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
def _get_custom_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
load_balancing_enabled=load_balancing_enabled
load_balancing_enabled=load_balancing_enabled,
)
)
@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
continue
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
status = ModelStatus.ACTIVE
load_balancing_enabled = False
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
load_balancing_enabled=load_balancing_enabled
load_balancing_enabled=load_balancing_enabled,
)
)
@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
def get_models(
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get available models.
@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
# pydantic configs
model_config = ConfigDict(arbitrary_types_allowed=True,
protected_namespaces=())
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

View File

@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType
class QuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
CREDITS = 'credits'
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = 'active'
QUOTA_EXCEEDED = 'quota-exceeded'
UNSUPPORTED = 'unsupported'
ACTIVE = "active"
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
class RestrictModel(BaseModel):
@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
credentials: dict
@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel):
"""
Model class for provider custom model configuration.
"""
model: str
model_type: ModelType
credentials: dict
@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
"""
Class for model load balancing configuration.
"""
id: str
name: str
credentials: dict
@ -93,6 +100,7 @@ class ModelSettings(BaseModel):
"""
Model class for model settings.
"""
model: str
model_type: ModelType
enabled: bool = True

View File

@ -3,6 +3,7 @@ from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
@ -11,6 +12,7 @@ class LLMError(Exception):
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
@ -28,6 +31,7 @@ class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
description = "App Invoke Quota Exceeded"
@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
"""
description = "Model Currently Not Support"
class InvokeRateLimitError(Exception):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"

View File

@ -20,10 +20,7 @@ class APIBasedExtensionRequestor:
:param params: the request params
:return: the response json
"""
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(self.api_key)
}
headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}
url = self.api_endpoint
@ -32,20 +29,17 @@ class APIBasedExtensionRequestor:
proxies = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
'http': dify_config.SSRF_PROXY_HTTP_URL,
'https': dify_config.SSRF_PROXY_HTTPS_URL,
"http": dify_config.SSRF_PROXY_HTTP_URL,
"https": dify_config.SSRF_PROXY_HTTPS_URL,
}
response = requests.request(
method='POST',
method="POST",
url=url,
json={
'point': point.value,
'params': params
},
json={"point": point.value, "params": params},
headers=headers,
timeout=self.timeout,
proxies=proxies
proxies=proxies,
)
except requests.exceptions.Timeout:
raise ValueError("request timeout")
@ -53,9 +47,8 @@ class APIBasedExtensionRequestor:
raise ValueError("request connection error")
if response.status_code != 200:
raise ValueError("request error, status_code: {}, content: {}".format(
response.status_code,
response.text[:100]
))
raise ValueError(
"request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
)
return response.json()

View File

@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
EXTERNAL_DATA_TOOL = 'external_data_tool'
MODERATION = "moderation"
EXTERNAL_DATA_TOOL = "external_data_tool"
class ModuleExtension(BaseModel):
@ -41,12 +41,12 @@ class Extensible:
position_map = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path)
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith('__'):
if subdir_name.startswith("__"):
continue
subdir_path = os.path.join(current_dir_path, subdir_name)
@ -58,21 +58,21 @@ class Extensible:
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
if '__builtin__' in file_names:
if "__builtin__" in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, '__builtin__')
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f:
with open(builtin_file_path, encoding="utf-8") as f:
position = int(f.read().strip())
position_map[extension_name] = position
if (extension_name + '.py') not in file_names:
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
@ -91,25 +91,29 @@ class Extensible:
json_data = {}
if not builtin:
if 'schema.json' not in file_names:
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, 'schema.json')
json_path = os.path.join(subdir_path, "schema.json")
json_data = {}
if os.path.exists(json_path):
with open(json_path, encoding='utf-8') as f:
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
extensions.append(ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
))
extensions.append(
ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get("label"),
form_schema=json_data.get("form_schema"),
builtin=builtin,
position=position,
)
)
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)
return sorted_extensions

View File

@ -6,10 +6,7 @@ from core.moderation.base import Moderation
class Extension:
__module_extensions: dict[str, dict[str, ModuleExtension]] = {}
module_classes = {
ExtensionModule.MODERATION: Moderation,
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
}
module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool}
def init(self):
for module, module_class in self.module_classes.items():

View File

@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = self.config.get("api_based_extension_id")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == self.tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.variable))
raise ValueError(
"[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid".format(self.variable)
)
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=self.tenant_id,
token=api_based_extension.api_key
)
api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key)
try:
# request api
requestor = APIBasedExtensionRequestor(
api_endpoint=api_based_extension.api_endpoint,
api_key=api_key
)
requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.variable,
e
))
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
'app_id': self.app_id,
'tool_variable': self.variable,
'inputs': inputs,
'query': query
})
response_json = requestor.request(
point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
)
if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.variable))
if "result" not in response_json:
raise ValueError(
"[External data tool] API query failed, variable: {}, error: result not found in response".format(
self.variable
)
)
if not isinstance(response_json['result'], str):
raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string"
.format(self.variable))
if not isinstance(response_json["result"], str):
raise ValueError(
"[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
)
return response_json['result']
return response_json["result"]

View File

@ -12,11 +12,14 @@ logger = logging.getLogger(__name__)
class ExternalDataFetch:
def fetch(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
def fetch(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
"""
Fill in variable inputs from external data tools if exists.
@ -38,7 +41,7 @@ class ExternalDataFetch:
app_id,
tool,
inputs,
query
query,
)
futures[future] = tool
@ -50,12 +53,15 @@ class ExternalDataFetch:
inputs.update(results)
return inputs
def _query_external_data_tool(self, flask_app: Flask,
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str) -> tuple[Optional[str], Optional[str]]:
def _query_external_data_tool(
self,
flask_app: Flask,
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str,
) -> tuple[Optional[str], Optional[str]]:
"""
Query external data tool.
:param flask_app: flask app
@ -72,17 +78,10 @@ class ExternalDataFetch:
tool_config = external_data_tool.config
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
result = external_data_tool_factory.query(inputs=inputs, query=query)
return tool_variable, result

View File

@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id,
app_id=app_id,
variable=variable,
config=config
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)
@classmethod

View File

@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
IMAGE = 'image'
IMAGE = "image"
@staticmethod
def value_of(value):
@ -28,9 +29,9 @@ class FileType(enum.Enum):
class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url'
LOCAL_FILE = 'local_file'
TOOL_FILE = 'tool_file'
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum):
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = 'user'
ASSISTANT = 'assistant'
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
@ -65,16 +67,16 @@ class FileVar(BaseModel):
def to_dict(self) -> dict:
return {
'__variant': self.__class__.__name__,
'tenant_id': self.tenant_id,
'type': self.type.value,
'transfer_method': self.transfer_method.value,
'url': self.preview_url,
'remote_url': self.url,
'related_id': self.related_id,
'filename': self.filename,
'extension': self.extension,
'mime_type': self.mime_type,
"__variant": self.__class__.__name__,
"tenant_id": self.tenant_id,
"type": self.type.value,
"transfer_method": self.transfer_method.value,
"url": self.preview_url,
"remote_url": self.url,
"related_id": self.related_id,
"filename": self.filename,
"extension": self.extension,
"mime_type": self.mime_type,
}
def to_markdown(self) -> str:
@ -86,7 +88,7 @@ class FileVar(BaseModel):
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})'
else:
text = f'[{self.filename or preview_url}]({preview_url})'
text = f"[{self.filename or preview_url}]({preview_url})"
return text
@ -115,28 +117,29 @@ class FileVar(BaseModel):
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == self.related_id,
UploadFile.tenant_id == self.tenant_id
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
upload_file = (
db.session.query(UploadFile)
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
.first()
)
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension
# add sign url
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension)
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None

View File

@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
user: Union[Account, EndUser]) -> list[FileVar]:
def validate_and_transform_files_arg(
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
"""
validate and transform files arg
@ -30,22 +30,22 @@ class MessageFileParser:
"""
for file in files:
if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict')
if not file.get('type'):
raise ValueError('Missing file type')
FileType.value_of(file.get('type'))
if not file.get('transfer_method'):
raise ValueError('Missing file transfer method')
FileTransferMethod.value_of(file.get('transfer_method'))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'):
raise ValueError('Missing file url')
if not file.get('url').startswith('http'):
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
raise ValueError('Missing file tool_file_id')
raise ValueError("Invalid file format, must be dict")
if not file.get("type"):
raise ValueError("Missing file type")
FileType.value_of(file.get("type"))
if not file.get("transfer_method"):
raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get("transfer_method"))
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get("url"):
raise ValueError("Missing file url")
if not file.get("url").startswith("http"):
raise ValueError("Invalid file url")
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError("Missing file upload_file_id")
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError("Missing file tool_file_id")
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
@ -62,17 +62,17 @@ class MessageFileParser:
continue
# Validate number of files
if len(files) > image_config['number_limits']:
if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']:
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f'Invalid file type: {file_obj.type}')
raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
@ -81,18 +81,21 @@ class MessageFileParser:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
UploadFile.extension.in_(IMAGE_EXTENSIONS)
).first())
upload_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError('Invalid upload file')
raise ValueError("Invalid upload file")
new_files.append(file_obj)
@ -113,8 +116,9 @@ class MessageFileParser:
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]:
def _to_file_objs(
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
"""
transform files to file objs
@ -152,23 +156,23 @@ class MessageFileParser:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config,
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=None,
related_id=file.get('tool_file_id'),
extra_config=file_extra_config
related_id=file.get("tool_file_id"),
extra_config=file_extra_config,
)
else:
return FileVar(
@ -178,7 +182,7 @@ class MessageFileParser:
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
related_id=file.upload_file_id or None,
extra_config=file_extra_config
extra_config=file_extra_config,
)
def _check_image_remote_url(self, url):
@ -190,17 +194,17 @@ class MessageFileParser:
def is_s3_presigned_url(url):
try:
parsed_url = urlparse(url)
if 'amazonaws.com' not in parsed_url.netloc:
if "amazonaws.com" not in parsed_url.netloc:
return False
query_params = parse_qs(parsed_url.query)
required_params = ['Signature', 'Expires']
required_params = ["Signature", "Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params['Expires'][0].isdigit():
if not query_params["Expires"][0].isdigit():
return False
signature = query_params['Signature'][0]
if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature):
signature = query_params["Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
except Exception:

View File

@ -1,8 +1,7 @@
tool_file_manager = {
'manager': None
}
tool_file_manager = {"manager": None}
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> 'ToolFileManager':
return tool_file_manager['manager']
def get_tool_file_manager() -> "ToolFileManager":
return tool_file_manager["manager"]

View File

@ -9,7 +9,7 @@ from typing import Optional
from configs import dify_config
from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
@ -22,18 +22,18 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url:
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}')
logging.error(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode('utf-8')
return f'data:{upload_file.mime_type};base64,{encoded_string}'
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
@ -44,7 +44,7 @@ class UploadFileParser:
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview'
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

View File

@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
class CodeExecutionException(Exception):
pass
class CodeExecutionResponse(BaseModel):
class Data(BaseModel):
stdout: Optional[str] = None
@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel):
class CodeLanguage(str, Enum):
PYTHON3 = 'python3'
JINJA2 = 'jinja2'
JAVASCRIPT = 'javascript'
PYTHON3 = "python3"
JINJA2 = "jinja2"
JAVASCRIPT = "javascript"
class CodeExecutor:
@ -45,63 +47,65 @@ class CodeExecutor:
}
code_language_to_running_language = {
CodeLanguage.JAVASCRIPT: 'nodejs',
CodeLanguage.JAVASCRIPT: "nodejs",
CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
}
supported_dependencies_languages: set[CodeLanguage] = {
CodeLanguage.PYTHON3
}
supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}
@classmethod
def execute_code(cls,
language: CodeLanguage,
preload: str,
code: str) -> str:
def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
"""
Execute code
:param language: code language
:param code: code
:return:
"""
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run'
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
headers = {
'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
}
headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
data = {
'language': cls.code_language_to_running_language.get(language),
'code': code,
'preload': preload,
'enable_network': True
"language": cls.code_language_to_running_language.get(language),
"code": code,
"preload": preload,
"enable_network": True,
}
try:
response = post(str(url), json=data, headers=headers,
timeout=Timeout(
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
pool=None))
response = post(
str(url),
json=data,
headers=headers,
timeout=Timeout(
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
pool=None,
),
)
if response.status_code == 503:
raise CodeExecutionException('Code execution service is unavailable')
raise CodeExecutionException("Code execution service is unavailable")
elif response.status_code != 200:
raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running')
raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
)
except CodeExecutionException as e:
raise e
except Exception as e:
raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
' please check if the sandbox service is running.'
f' ( Error: {str(e)} )')
raise CodeExecutionException(
"Failed to execute code, which is likely a network issue,"
" please check if the sandbox service is running."
f" ( Error: {str(e)} )"
)
try:
response = response.json()
except:
raise CodeExecutionException('Failed to parse response')
raise CodeExecutionException("Failed to parse response")
if (code := response.get('code')) != 0:
if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
response = CodeExecutionResponse(**response)
@ -109,7 +113,7 @@ class CodeExecutor:
if response.data.error:
raise CodeExecutionException(response.data.error)
return response.data.stdout or ''
return response.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
@ -122,7 +126,7 @@ class CodeExecutor:
"""
template_transformer = cls.code_template_transformers.get(language)
if not template_transformer:
raise CodeExecutionException(f'Unsupported language {language}')
raise CodeExecutionException(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)

View File

@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel):
return {
"type": "code",
"config": {
"variables": [
{
"variable": "arg1",
"value_selector": []
},
{
"variable": "arg2",
"value_selector": []
}
],
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {
"result": {
"type": "string",
"children": None
}
}
}
"outputs": {"result": {"type": "string", "children": None}},
},
}

View File

@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider):
result: arg1 + arg2
}
}
""")
"""
)

View File

@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
""")
"""
)
return runner_script

View File

@ -10,8 +10,6 @@ class Jinja2Formatter:
:param inputs: inputs
:return:
"""
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=inputs
)
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
return result['result']
return result["result"]

View File

@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
:param response: response
:return:
"""
return {
'result': cls.extract_result_str_from_response(response)
}
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
def get_runner_script(cls) -> str:

View File

@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider):
return {
"result": arg1 + arg2,
}
""")
"""
)

Some files were not shown because too many files have changed in this diff Show More