Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
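In outline, the change bumps CURRENT_VERSION from 0.7.3 to 0.8.0 and rebuilds the advanced-chat generator, runner, and task pipeline around a graph engine whose branches can run concurrently and report progress as queue events. As a rough, hypothetical sketch of that execution model (BranchEvent and run_branches are illustrative stand-ins, not Dify's actual GraphEngine API):

# Hypothetical sketch of parallel branch execution: each branch of a graph
# runs in its own thread and reports started/succeeded/failed events, which
# a consumer can stream out while sibling branches are still running.
import queue
import threading
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class BranchEvent:
    branch_id: str
    kind: str  # "started" | "succeeded" | "failed"
    error: Optional[str] = None


def run_branches(branches: dict[str, Callable[[], None]]) -> list[BranchEvent]:
    events: "queue.Queue[BranchEvent]" = queue.Queue()

    def worker(branch_id: str, fn: Callable[[], None]) -> None:
        events.put(BranchEvent(branch_id, "started"))
        try:
            fn()
            events.put(BranchEvent(branch_id, "succeeded"))
        except Exception as e:  # report the failure, don't crash sibling branches
            events.put(BranchEvent(branch_id, "failed", error=str(e)))

    threads = [threading.Thread(target=worker, args=(bid, fn)) for bid, fn in branches.items()]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return [events.get() for _ in range(events.qsize())]


if __name__ == "__main__":
    print(run_branches({"a": lambda: None, "b": lambda: 1 / 0}))

Reporting events instead of returning values is what lets the task pipeline below stream QueueParallelBranchRun* responses while other branches are still in flight.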
commit dabfd74622
parent 5da0182800
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.7.3",
+        default="0.8.0",
     )
 
     COMMIT_SHA: str = Field(

@@ -4,12 +4,10 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Literal, Union, overload
+from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
-from sqlalchemy import select
-from sqlalchemy.orm import Session
 
 import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager

@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, Workflow
+from models.workflow import Workflow
 
 logger = logging.getLogger(__name__)
 

@@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ):
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 

@@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                                   node_id: str,
                                   user: Account,
                                   args: dict,
-                                  stream: bool = True):
+                                  stream: bool = True) \
+            -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 

@@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         if args.get('inputs') is None:
             raise ValueError('inputs is required')
 
-        extras = {
-            "auto_generate_conversation_name": False
-        }
-
-        # get conversation
-        conversation = None
-        conversation_id = args.get('conversation_id')
-        if conversation_id:
-            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
-
         # convert to app config
         app_config = AdvancedChatAppConfigManager.get_app_config(
             app_model=app_model,

@@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         application_generate_entity = AdvancedChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            conversation_id=conversation.id if conversation else None,
+            conversation_id=None,
             inputs={},
             query='',
             files=[],
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras=extras,
+            extras={
+                "auto_generate_conversation_name": False
+            },
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
                 node_id=node_id,
                 inputs=args['inputs']

@@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            conversation=conversation,
+            conversation=None,
             stream=stream
         )
 
     def _generate(self, *,
                   workflow: Workflow,
                   user: Union[Account, EndUser],
                   invoke_from: InvokeFrom,
                   application_generate_entity: AdvancedChatAppGenerateEntity,
-                  conversation: Conversation | None = None,
-                  stream: bool = True):
+                  conversation: Optional[Conversation] = None,
+                  stream: bool = True) \
+            -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
+        :param workflow: Workflow
+        :param user: account or end user
+        :param invoke_from: invoke from source
+        :param application_generate_entity: application generate entity
+        :param conversation: conversation
+        :param stream: is stream
+        """
+        is_first_conversation = False
+        if not conversation:
+            is_first_conversation = True

@@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             # update conversation features
             conversation.override_model_configs = workflow.features
             db.session.commit()
-            # db.session.refresh(conversation)
+            db.session.refresh(conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(

@@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message_id=message.id
         )
 
-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
-        )
-        with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
-                # Create conversation variables if they don't exist.
-                conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
-                    )
-                    for variable in workflow.conversation_variables
-                ]
-                session.add_all(conversation_variables)
-            # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
-
-            session.commit()
-
-        # Increment dialogue count.
-        conversation.dialogue_count += 1
-
-        conversation_id = conversation.id
-        conversation_dialogue_count = conversation.dialogue_count
-        db.session.commit()
-        db.session.refresh(conversation)
-
-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
-        files = application_generate_entity.files
-
-        user_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
-            if end_user:
-                user_id = end_user.session_id
-        else:
-            user_id = application_generate_entity.user_id
-
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariableKey.QUERY: query,
-            SystemVariableKey.FILES: files,
-            SystemVariableKey.CONVERSATION_ID: conversation_id,
-            SystemVariableKey.USER_ID: user_id,
-            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
-        }
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=conversation_variables,
-        )
-        contexts.workflow_variable_pool.set(variable_pool)
-
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
+            'flask_app': current_app._get_current_object(),  # type: ignore
             'application_generate_entity': application_generate_entity,
             'queue_manager': queue_manager,
             'conversation_id': conversation.id,
             'message_id': message.id,
             'context': contextvars.copy_context(),
         })

@@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
+                         conversation_id: str,
                          message_id: str,
                          context: contextvars.Context) -> None:
         """

@@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 var.set(val)
         with flask_app.app_context():
             try:
-                runner = AdvancedChatAppRunner()
-                if application_generate_entity.single_iteration_run:
-                    single_iteration_run = application_generate_entity.single_iteration_run
-                    runner.single_iteration_run(
-                        app_id=application_generate_entity.app_config.app_id,
-                        workflow_id=application_generate_entity.app_config.workflow_id,
-                        queue_manager=queue_manager,
-                        inputs=single_iteration_run.inputs,
-                        node_id=single_iteration_run.node_id,
-                        user_id=application_generate_entity.user_id
-                    )
-                else:
-                    # get message
-                    message = self._get_message(message_id)
+                # get conversation and message
+                conversation = self._get_conversation(conversation_id)
+                message = self._get_message(message_id)
 
-                    # chatbot app
-                    runner = AdvancedChatAppRunner()
-                    runner.run(
-                        application_generate_entity=application_generate_entity,
-                        queue_manager=queue_manager,
-                        message=message
-                    )
+                # chatbot app
+                runner = AdvancedChatAppRunner(
+                    application_generate_entity=application_generate_entity,
+                    queue_manager=queue_manager,
+                    conversation=conversation,
+                    message=message
+                )
+
+                runner.run()
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:

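The generator above hands the request to _generate_worker on a new thread, passing contextvars.copy_context() and re-applying each variable with var.set(val) so context-local state survives the thread hop. A minimal standalone sketch of that pattern (request_id is an illustrative variable, not one of Dify's):

# Minimal sketch of the copy-context-into-worker-thread pattern used by
# _generate_worker above; `request_id` stands in for Dify's context vars.
import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")


def worker(context: contextvars.Context) -> None:
    # Re-apply the parent thread's context values inside the worker thread.
    for var, val in context.items():
        var.set(val)
    print("worker sees:", request_id.get())


def main() -> None:
    request_id.set("req-42")
    ctx = contextvars.copy_context()
    t = threading.Thread(target=worker, kwargs={"context": ctx})
    t.start()
    t.join()


if __name__ == "__main__":
    main()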
@@ -1,49 +1,67 @@
 import logging
 import os
 import time
 from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
-from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
-from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.apps.base_app_runner import AppRunner
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     InvokeFrom,
 )
-from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
+from core.app.entities.queue_entities import (
+    QueueAnnotationReplyEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+)
 from core.moderation.base import ModerationException
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.nodes.base_node import UserFrom
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.entities.node_entities import UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from models import App, Message, Workflow
+from models.model import App, Conversation, EndUser, Message
+from models.workflow import ConversationVariable, WorkflowType
 
 logger = logging.getLogger(__name__)
 
 
-class AdvancedChatAppRunner(AppRunner):
+class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     """
     AdvancedChat Application Runner
     """
 
-    def run(
+    def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
         queue_manager: AppQueueManager,
-        message: Message,
+        conversation: Conversation,
+        message: Message
     ) -> None:
         """
-        Run application
         :param application_generate_entity: application generate entity
         :param queue_manager: application queue manager
+        :param conversation: conversation
         :param message: message
         """
+        super().__init__(queue_manager)
+
+        self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
+        self.message = message
+
+    def run(self) -> None:
+        """
+        Run application
+        :return:
+        """
-        app_config = application_generate_entity.app_config
+        app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()

@@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
         if not workflow:
             raise ValueError('Workflow not initialized')
 
-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
+        user_id = None
+        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = self.application_generate_entity.user_id
 
-        # moderation
-        if self.handle_input_moderation(
-            queue_manager=queue_manager,
-            app_record=app_record,
-            app_generate_entity=application_generate_entity,
-            inputs=inputs,
-            query=query,
-            message_id=message.id,
-        ):
-            return
+        workflow_callbacks: list[WorkflowCallback] = []
+        if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+            workflow_callbacks.append(WorkflowLoggingCallback())
 
-        # annotation reply
-        if self.handle_annotation_reply(
-            app_record=app_record,
-            message=message,
-            query=query,
-            queue_manager=queue_manager,
-            app_generate_entity=application_generate_entity,
-        ):
-            return
+        if self.application_generate_entity.single_iteration_run:
+            # if only single iteration run is requested
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+                workflow=workflow,
+                node_id=self.application_generate_entity.single_iteration_run.node_id,
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs
+            )
+        else:
+            inputs = self.application_generate_entity.inputs
+            query = self.application_generate_entity.query
+            files = self.application_generate_entity.files
+
+            # moderation
+            if self.handle_input_moderation(
+                app_record=app_record,
+                app_generate_entity=self.application_generate_entity,
+                inputs=inputs,
+                query=query,
+                message_id=self.message.id
+            ):
+                return
+
+            # annotation reply
+            if self.handle_annotation_reply(
+                app_record=app_record,
+                message=self.message,
+                query=query,
+                app_generate_entity=self.application_generate_entity
+            ):
+                return
+
+            # Init conversation variables
+            stmt = select(ConversationVariable).where(
+                ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
+            )
+            with Session(db.engine) as session:
+                conversation_variables = session.scalars(stmt).all()
+                if not conversation_variables:
+                    # Create conversation variables if they don't exist.
+                    conversation_variables = [
+                        ConversationVariable.from_variable(
+                            app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
+                        )
+                        for variable in workflow.conversation_variables
+                    ]
+                    session.add_all(conversation_variables)
+                # Convert database entities to variables.
+                conversation_variables = [item.to_variable() for item in conversation_variables]
+
+                session.commit()
+
+            # Increment dialogue count.
+            self.conversation.dialogue_count += 1
+
+            conversation_dialogue_count = self.conversation.dialogue_count
+            db.session.commit()
+
+            # Create a variable pool.
+            system_inputs = {
+                SystemVariableKey.QUERY: query,
+                SystemVariableKey.FILES: files,
+                SystemVariableKey.CONVERSATION_ID: self.conversation.id,
+                SystemVariableKey.USER_ID: user_id,
+                SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+            }
+
+            # init variable pool
+            variable_pool = VariablePool(
+                system_variables=system_inputs,
+                user_inputs=inputs,
+                environment_variables=workflow.environment_variables,
+                conversation_variables=conversation_variables,
+            )
+
+            # init graph
+            graph = self._init_graph(graph_config=workflow.graph_dict)
 
-        workflow_callbacks: list[WorkflowCallback] = [
-            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
-        ]
-
-        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         db.session.close()
 
-        # RUN WORKFLOW
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.run_workflow(
-            workflow=workflow,
-            user_id=application_generate_entity.user_id,
-            user_from=UserFrom.ACCOUNT
-            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
-            else UserFrom.END_USER,
-            invoke_from=application_generate_entity.invoke_from,
+        workflow_entry = WorkflowEntry(
+            tenant_id=workflow.tenant_id,
+            app_id=workflow.app_id,
+            workflow_id=workflow.id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            graph=graph,
+            graph_config=workflow.graph_dict,
+            user_id=self.application_generate_entity.user_id,
+            user_from=(
+                UserFrom.ACCOUNT
+                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                else UserFrom.END_USER
+            ),
+            invoke_from=self.application_generate_entity.invoke_from,
+            call_depth=self.application_generate_entity.call_depth,
+            variable_pool=variable_pool,
+        )
+
+        generator = workflow_entry.run(
             callbacks=workflow_callbacks,
-            call_depth=application_generate_entity.call_depth,
         )
 
-    def single_iteration_run(
-        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
-    ) -> None:
-        """
-        Single iteration run
-        """
-        app_record = db.session.query(App).filter(App.id == app_id).first()
-        if not app_record:
-            raise ValueError('App not found')
-
-        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
-        if not workflow:
-            raise ValueError('Workflow not initialized')
-
-        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
-
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.single_step_run_iteration_workflow_node(
-            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
-        )
-
-    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
-        """
-        Get workflow
-        """
-        # fetch workflow by workflow_id
-        workflow = (
-            db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
-            )
-            .first()
-        )
-
-        # return workflow
-        return workflow
+        for event in generator:
+            self._handle_event(workflow_entry, event)
 
     def handle_input_moderation(
         self,
-        queue_manager: AppQueueManager,
         app_record: App,
         app_generate_entity: AdvancedChatAppGenerateEntity,
         inputs: Mapping[str, Any],
         query: str,
-        message_id: str,
+        message_id: str
     ) -> bool:
         """
         Handle input moderation
-        :param queue_manager: application queue manager
         :param app_record: app record
         :param app_generate_entity: application generate entity
         :param inputs: inputs

@@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
                 message_id=message_id,
             )
         except ModerationException as e:
-            self._stream_output(
-                queue_manager=queue_manager,
+            self._complete_with_stream_output(
                 text=str(e),
-                stream=app_generate_entity.stream,
-                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
+                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
             )
             return True
 
         return False
 
-    def handle_annotation_reply(
-        self,
-        app_record: App,
-        message: Message,
-        query: str,
-        queue_manager: AppQueueManager,
-        app_generate_entity: AdvancedChatAppGenerateEntity,
-    ) -> bool:
+    def handle_annotation_reply(self, app_record: App,
+                                message: Message,
+                                query: str,
+                                app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
         """
         Handle annotation reply
         :param app_record: app record
         :param message: message
         :param query: query
-        :param queue_manager: application queue manager
        :param app_generate_entity: application generate entity
        """
        # annotation reply

@@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
             )
 
             if annotation_reply:
-                queue_manager.publish(
-                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
+                self._publish_event(
+                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
                 )
 
-                self._stream_output(
-                    queue_manager=queue_manager,
+                self._complete_with_stream_output(
                     text=annotation_reply.content,
-                    stream=app_generate_entity.stream,
-                    stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
+                    stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
                 )
                 return True
 
         return False
 
-    def _stream_output(
-        self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
-    ) -> None:
+    def _complete_with_stream_output(self,
+                                     text: str,
+                                     stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param queue_manager: application queue manager
         :param text: text
-        :param stream: stream
         :return:
         """
-        if stream:
-            index = 0
-            for token in text:
-                queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
-                index += 1
-                time.sleep(0.01)
-        else:
-            queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(
+            QueueTextChunkEvent(
+                text=text
+            )
+        )
 
-        queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(
+            QueueStopEvent(stopped_by=stopped_by)
+        )

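The reworked runner no longer drives a WorkflowEngineManager; it builds a WorkflowEntry, consumes the generator its run() returns, and forwards each event via _handle_event. A toy version of that consume-and-dispatch loop (GraphEvent, fake_engine, and handle_event are stand-ins, not Dify classes):

# Illustrative consumer loop in the style of AdvancedChatAppRunner.run() above.
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class GraphEvent:
    name: str
    payload: dict


def fake_engine() -> Generator[GraphEvent, None, None]:
    # A real engine yields events as nodes and parallel branches progress.
    yield GraphEvent("node_started", {"node_id": "llm-1"})
    yield GraphEvent("node_succeeded", {"node_id": "llm-1"})


def handle_event(event: GraphEvent) -> None:
    # Real code maps each engine event onto a queue event for the task pipeline.
    print(f"{event.name}: {event.payload}")


for event in fake_engine():
    handle_event(event)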
@@ -2,9 +2,8 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
 
-import contexts
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom

@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,

@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.app.entities.task_entities import (
-    AdvancedChatTaskState,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
-    ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     MessageEndStreamResponse,
     StreamResponse,
+    WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
-from core.file.file_obj import FileVar
-from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
 

@@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
-    _task_state: AdvancedChatTaskState
+    _task_state: WorkflowTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
     _user: Union[Account, EndUser]
+    # Deprecated
     _workflow_system_variables: dict[SystemVariableKey, Any]
-    _iteration_nested_relations: dict[str, list[str]]
 
     def __init__(
-        self, application_generate_entity: AdvancedChatAppGenerateEntity,
+        self,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
         workflow: Workflow,
         queue_manager: AppQueueManager,
         conversation: Conversation,

@@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._workflow = workflow
         self._conversation = conversation
         self._message = message
+        # Deprecated
         self._workflow_system_variables = {
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,

@@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             SystemVariableKey.USER_ID: user_id,
         }
 
-        self._task_state = AdvancedChatTaskState(
-            usage=LLMUsage.empty_usage()
-        )
+        self._task_state = WorkflowTaskState()
 
-        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
-        self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
     def process(self):

@@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         generator = self._wrapper_process_stream_response(
             trace_manager=self._application_generate_entity.trace_manager
         )
+
         if self._stream:
             return self._to_stream_response(generator)
         else:

@@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
             Generator[StreamResponse, None, None]:
 
-        publisher = None
+        tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
         features_dict = self._workflow.features_dict
 
         if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                 'text_to_speech'].get('autoPlay') == 'enabled':
-            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
-        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+
+        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:

@@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             # timeout
             while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
                 try:
-                    if not publisher:
+                    if not tts_publisher:
                         break
-                    audio_trunk = publisher.checkAndGetAudio()
+                    audio_trunk = tts_publisher.checkAndGetAudio()
                     if audio_trunk is None:
                         # release cpu
                         # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

@@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
     def _process_stream_response(
         self,
-        publisher: AppGeneratorTTSPublisher,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
         trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
         """
-        for message in self._queue_manager.listen():
-            if (message.event
-                    and getattr(message.event, 'metadata', None)
-                    and message.event.metadata.get('is_answer_previous_node', False)
-                    and publisher):
-                publisher.publish(message=message)
-            elif (hasattr(message.event, 'execution_metadata')
-                    and message.event.execution_metadata
-                    and message.event.execution_metadata.get('is_answer_previous_node', False)
-                    and publisher):
-                publisher.publish(message=message)
-            event = message.event
+        # init fake graph runtime state
+        graph_runtime_state = None
+        workflow_run = None
 
-            if isinstance(event, QueueErrorEvent):
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueuePingEvent):
+                yield self._ping_stream_response()
+            elif isinstance(event, QueueErrorEvent):
                 err = self._handle_error(event, self._message)
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
-                workflow_run = self._handle_workflow_start()
+                # override graph runtime state
+                graph_runtime_state = event.graph_runtime_state
 
-                self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+                # init workflow run
+                workflow_run = self._handle_workflow_run_start()
+
+                self._refetch_message()
                 self._message.workflow_run_id = workflow_run.id
 
                 db.session.commit()

@@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_run=workflow_run
                 )
             elif isinstance(event, QueueNodeStartedEvent):
-                workflow_node_execution = self._handle_node_start(event)
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                # search stream_generate_routes if node id is answer start at node
-                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
-                    self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
-                    # reset current route position to 0
-                    self._task_state.current_stream_generate_state.current_route_position = 0
+                workflow_node_execution = self._handle_node_execution_start(
+                    workflow_run=workflow_run,
+                    event=event
+                )
 
-                # generate stream outputs when node started
-                yield from self._generate_stream_outputs_when_node_started()
-
-                yield self._workflow_node_start_to_stream_response(
+                response = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
-            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
-                workflow_node_execution = self._handle_node_finished(event)
 
-                # stream outputs when node finished
-                generator = self._generate_stream_outputs_when_node_finished()
-                if generator:
-                    yield from generator
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeSucceededEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_success(event)
 
-                yield self._workflow_node_finish_to_stream_response(
+                response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
 
-                if isinstance(event, QueueNodeFailedEvent):
-                    yield from self._handle_iteration_exception(
-                        task_id=self._application_generate_entity.task_id,
-                        error=f'Child node failed: {event.error}'
-                    )
-            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
-                if isinstance(event, QueueIterationNextEvent):
-                    # clear ran node execution infos of current iteration
-                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
-                    if iteration_relations:
-                        for node_id in iteration_relations:
-                            self._task_state.ran_node_execution_infos.pop(node_id, None)
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeFailedEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
-                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
-                self._handle_iteration_operation(event)
-            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                workflow_run = self._handle_workflow_finished(
-                    event, conversation_id=self._conversation.id, trace_manager=trace_manager
+                response = self._workflow_node_finish_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution
                 )
-                if workflow_run:
 
+                if response:
+                    yield response
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_finished_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationStartEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationNextEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_next_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationCompletedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_completed_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueWorkflowSucceededEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_success(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    outputs=json.dumps(event.outputs) if event.outputs else None,
+                    conversation_id=self._conversation.id,
+                    trace_manager=trace_manager,
+                )
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+
+                self._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(),
+                    PublishFrom.TASK_PIPELINE
+                )
+            elif isinstance(event, QueueWorkflowFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_failed(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    status=WorkflowRunStatus.FAILED,
+                    error=event.error,
+                    conversation_id=self._conversation.id,
+                    trace_manager=trace_manager,
+                )
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+
+                err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
+                break
+            elif isinstance(event, QueueStopEvent):
+                if workflow_run and graph_runtime_state:
+                    workflow_run = self._handle_workflow_run_failed(
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        status=WorkflowRunStatus.STOPPED,
+                        error=event.get_stop_reason(),
+                        conversation_id=self._conversation.id,
+                        trace_manager=trace_manager,
+                    )
+
                     yield self._workflow_finish_to_stream_response(
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run
                     )
 
-                    if workflow_run.status == WorkflowRunStatus.FAILED.value:
-                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
-                        yield self._error_to_stream_response(self._handle_error(err_event, self._message))
-                        break
-
-                if isinstance(event, QueueStopEvent):
-                    # Save message
-                    self._save_message()
-
-                    yield self._message_end_to_stream_response()
-                    break
-                else:
-                    self._queue_manager.publish(
-                        QueueAdvancedChatMessageEndEvent(),
-                        PublishFrom.TASK_PIPELINE
-                    )
-            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
-                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
-                if output_moderation_answer:
-                    self._task_state.answer = output_moderation_answer
-                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
-
                 # Save message
-                self._save_message()
+                self._save_message(graph_runtime_state=graph_runtime_state)
 
                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._handle_retriever_resources(event)
+
+                self._refetch_message()
+
+                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
+                    if self._task_state.metadata else None
+
+                db.session.commit()
+                db.session.refresh(self._message)
+                db.session.close()
             elif isinstance(event, QueueAnnotationReplyEvent):
                 self._handle_annotation_reply(event)
+
+                self._refetch_message()
+
+                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
+                    if self._task_state.metadata else None
+
+                db.session.commit()
+                db.session.refresh(self._message)
+                db.session.close()
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
                     continue
 
-                if not self._is_stream_out_support(
-                        event=event
-                ):
-                    continue
-
                 # handle output moderation chunk
                 should_direct_answer = self._handle_output_moderation_chunk(delta_text)
                 if should_direct_answer:
                     continue
 
+                # only publish tts message at text chunk streaming
+                if tts_publisher:
+                    tts_publisher.publish(message=queue_message)
+
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(delta_text, self._message.id)
             elif isinstance(event, QueueMessageReplaceEvent):
+                # published by moderation
                 yield self._message_replace_to_stream_response(answer=event.text)
-            elif isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                if output_moderation_answer:
+                    self._task_state.answer = output_moderation_answer
+                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+
+                # Save message
+                self._save_message(graph_runtime_state=graph_runtime_state)
+
+                yield self._message_end_to_stream_response()
             else:
                 continue
-        if publisher:
-            publisher.publish(None)
+
+        # publish None when task finished
+        if tts_publisher:
+            tts_publisher.publish(None)
+
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(self) -> None:
+    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         """
         Save message.
         :return:
         """
-        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        self._refetch_message()
 
         self._message.answer = self._task_state.answer
         self._message.provider_response_latency = time.perf_counter() - self._start_at
         self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
             if self._task_state.metadata else None
 
-        if self._task_state.metadata and self._task_state.metadata.get('usage'):
-            usage = LLMUsage(**self._task_state.metadata['usage'])
-
+        if graph_runtime_state and graph_runtime_state.llm_usage:
+            usage = graph_runtime_state.llm_usage
             self._message.message_tokens = usage.prompt_tokens
             self._message.message_unit_price = usage.prompt_unit_price
             self._message.message_price_unit = usage.prompt_price_unit

@@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata
+            extras['metadata'] = self._task_state.metadata.copy()
+
+            if 'annotation_reply' in extras['metadata']:
+                del extras['metadata']['annotation_reply']
 
         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,

@@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             **extras
         )
 
-    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
-        """
-        Get stream generate routes.
-        :return:
-        """
-        # find all answer nodes
-        graph = self._workflow.graph_dict
-        answer_node_configs = [
-            node for node in graph['nodes']
-            if node.get('data', {}).get('type') == NodeType.ANSWER.value
-        ]
-
-        # parse stream output node value selectors of answer nodes
-        stream_generate_routes = {}
-        for node_config in answer_node_configs:
-            # get generate route for stream output
-            answer_node_id = node_config['id']
-            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
-            start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
-            if not start_node_ids:
-                continue
-
-            for start_node_id in start_node_ids:
-                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
-                    answer_node_id=answer_node_id,
-                    generate_route=generate_route
-                )
-
-        return stream_generate_routes
-
-    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-            -> list[str]:
-        """
-        Get answer start at node id.
-        :param graph: graph
-        :param target_node_id: target node ID
-        :return:
-        """
-        nodes = graph.get('nodes')
-        edges = graph.get('edges')
-
-        # fetch all ingoing edges from source node
-        ingoing_edges = []
-        for edge in edges:
-            if edge.get('target') == target_node_id:
-                ingoing_edges.append(edge)
-
-        if not ingoing_edges:
-            # check if it's the first node in the iteration
-            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
-            if not target_node:
-                return []
-
-            node_iteration_id = target_node.get('data', {}).get('iteration_id')
-            # get iteration start node id
-            for node in nodes:
-                if node.get('id') == node_iteration_id:
-                    if node.get('data', {}).get('start_node_id') == target_node_id:
-                        return [target_node_id]
-
-            return []
-
-        start_node_ids = []
-        for ingoing_edge in ingoing_edges:
-            source_node_id = ingoing_edge.get('source')
-            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
-            if not source_node:
-                continue
-
-            node_type = source_node.get('data', {}).get('type')
-            node_iteration_id = source_node.get('data', {}).get('iteration_id')
-            iteration_start_node_id = None
-            if node_iteration_id:
-                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
-                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
-
-            if node_type in [
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
-                NodeType.LOOP.value
-            ]:
-                start_node_id = target_node_id
-                start_node_ids.append(start_node_id)
-            elif node_type == NodeType.START.value or \
-                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
-                start_node_id = source_node_id
-                start_node_ids.append(start_node_id)
-            else:
-                sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
-                if sub_start_node_ids:
-                    start_node_ids.extend(sub_start_node_ids)
-
-        return start_node_ids
-
-    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
-        """
-        Get iteration nested relations.
-        :param graph: graph
-        :return:
-        """
-        nodes = graph.get('nodes')
-
-        iteration_ids = [node.get('id') for node in nodes
-                         if node.get('data', {}).get('type') in [
-                             NodeType.ITERATION.value,
-                             NodeType.LOOP.value,
-                         ]]
-
-        return {
-            iteration_id: [
-                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
-            ] for iteration_id in iteration_ids
-        }
-
-    def _generate_stream_outputs_when_node_started(self) -> Generator:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if self._task_state.current_stream_generate_state:
-            route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                self._task_state.current_stream_generate_state.current_route_position:
-            ]
-
-            for route_chunk in route_chunks:
-                if route_chunk.type == 'text':
-                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-
-                    # handle output moderation chunk
-                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
-                    if should_direct_answer:
-                        continue
-
-                    self._task_state.answer += route_chunk.text
-                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
-                else:
-                    break
-
-                self._task_state.current_stream_generate_state.current_route_position += 1
-
-            # all route chunks are generated
-            if self._task_state.current_stream_generate_state.current_route_position == len(
-                    self._task_state.current_stream_generate_state.generate_route
-            ):
-                self._task_state.current_stream_generate_state = None
-
-    def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if not self._task_state.current_stream_generate_state:
-            return
-
-        route_chunks = self._task_state.current_stream_generate_state.generate_route[
-            self._task_state.current_stream_generate_state.current_route_position:]
-
-        for route_chunk in route_chunks:
-            if route_chunk.type == 'text':
-                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                self._task_state.answer += route_chunk.text
-                yield self._message_to_stream_response(route_chunk.text, self._message.id)
-            else:
-                value = None
-                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-                value_selector = route_chunk.value_selector
-                if not value_selector:
-                    self._task_state.current_stream_generate_state.current_route_position += 1
-                    continue
-
-                route_chunk_node_id = value_selector[0]
-
-                if route_chunk_node_id == 'sys':
-                    # system variable
-                    value = contexts.workflow_variable_pool.get().get(value_selector)
-                    if value:
-                        value = value.text
-                elif route_chunk_node_id in self._iteration_nested_relations:
-                    # it's a iteration variable
-                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
-                        continue
-                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
-                    iterator = iteration_state.inputs
-                    if not iterator:
-                        continue
-                    iterator_selector = iterator.get('iterator_selector', [])
-                    if value_selector[1] == 'index':
-                        value = iteration_state.current_index
-                    elif value_selector[1] == 'item':
-                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
-                            iterator_selector
-                        ) else None
-                else:
-                    # check chunk node id is before current node id or equal to current node id
-                    if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
-                        break
-
-                    latest_node_execution_info = self._task_state.latest_node_execution_info
-
-                    # get route chunk node execution info
-                    route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
-                    if (route_chunk_node_execution_info.node_type == NodeType.LLM
-                            and latest_node_execution_info.node_type == NodeType.LLM):
-                        # only LLM support chunk stream output
-                        self._task_state.current_stream_generate_state.current_route_position += 1
-                        continue
-
-                    # get route chunk node execution
-                    route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
-                    ).first()
-
-                    outputs = route_chunk_node_execution.outputs_dict
-
-                    # get value from outputs
-                    value = None
-                    for key in value_selector[1:]:
-                        if not value:
-                            value = outputs.get(key) if outputs else None
-                        else:
-                            value = value.get(key)
-
-                if value is not None:
-                    text = ''
-                    if isinstance(value, str | int | float):
-                        text = str(value)
-                    elif isinstance(value, FileVar):
-                        # convert file to markdown
-                        text = value.to_markdown()
-                    elif isinstance(value, dict):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        if file_vars:
-                            file_var = file_vars[0]
-                            try:
-                                file_var_obj = FileVar(**file_var)
-
-                                # convert file to markdown
-                                text = file_var_obj.to_markdown()
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-
-                        if not text:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-                    elif isinstance(value, list):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        for file_var in file_vars:
-                            try:
-                                file_var_obj = FileVar(**file_var)
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-                                continue
-
-                            # convert file to markdown
-                            text = file_var_obj.to_markdown() + ' '
-
-                        text = text.strip()
-
-                        if not text and value:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-
-                    if text:
-                        self._task_state.answer += text
-                        yield self._message_to_stream_response(text, self._message.id)
-
-            self._task_state.current_stream_generate_state.current_route_position += 1
-
-        # all route chunks are generated
-        if self._task_state.current_stream_generate_state.current_route_position == len(
-                self._task_state.current_stream_generate_state.generate_route
-        ):
-            self._task_state.current_stream_generate_state = None
-
-    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
-        """
-        Is stream out support
-        :param event: queue text chunk event
-        :return:
-        """
-        if not event.metadata:
-            return True
-
-        if 'node_id' not in event.metadata:
-            return True
-
-        node_type = event.metadata.get('node_type')
-        stream_output_value_selector = event.metadata.get('value_selector')
-        if not stream_output_value_selector:
-            return False
-
-        if not self._task_state.current_stream_generate_state:
-            return False
-
-        route_chunk = self._task_state.current_stream_generate_state.generate_route[
-            self._task_state.current_stream_generate_state.current_route_position]
-
-        if route_chunk.type != 'var':
-            return False
-
-        if node_type != NodeType.LLM:
-            # only LLM support chunk stream output
-            return False
-
-        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-        value_selector = route_chunk.value_selector
-
-        # check chunk node id is before current node id or equal to current node id
-        if value_selector != stream_output_value_selector:
-            return False
-
-        return True
-
     def _handle_output_moderation_chunk(self, text: str) -> bool:
         """
         Handle output moderation chunk.

@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            self._output_moderation_handler.append_new_token(text)

        return False

    def _refetch_message(self) -> None:
        """
        Refetch message.
        :return:
        """
        message = db.session.query(Message).filter(Message.id == self._message.id).first()
        if message:
            self._message = message

@ -1,203 +0,0 @@
from typing import Any, Optional

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow


class WorkflowEventTriggerCallback(WorkflowCallback):

    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
        self._queue_manager = queue_manager

    def on_workflow_run_started(self) -> None:
        """
        Workflow run started
        """
        self._queue_manager.publish(
            QueueWorkflowStartedEvent(),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_succeeded(self) -> None:
        """
        Workflow run succeeded
        """
        self._queue_manager.publish(
            QueueWorkflowSucceededEvent(),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_failed(self, error: str) -> None:
        """
        Workflow run failed
        """
        self._queue_manager.publish(
            QueueWorkflowFailedEvent(
                error=error
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_started(self, node_id: str,
                                         node_type: NodeType,
                                         node_data: BaseNodeData,
                                         node_run_index: int = 1,
                                         predecessor_node_id: Optional[str] = None) -> None:
        """
        Workflow node execute started
        """
        self._queue_manager.publish(
            QueueNodeStartedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                node_run_index=node_run_index,
                predecessor_node_id=predecessor_node_id
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_succeeded(self, node_id: str,
                                           node_type: NodeType,
                                           node_data: BaseNodeData,
                                           inputs: Optional[dict] = None,
                                           process_data: Optional[dict] = None,
                                           outputs: Optional[dict] = None,
                                           execution_metadata: Optional[dict] = None) -> None:
        """
        Workflow node execute succeeded
        """
        self._queue_manager.publish(
            QueueNodeSucceededEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                execution_metadata=execution_metadata
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_failed(self, node_id: str,
                                        node_type: NodeType,
                                        node_data: BaseNodeData,
                                        error: str,
                                        inputs: Optional[dict] = None,
                                        outputs: Optional[dict] = None,
                                        process_data: Optional[dict] = None) -> None:
        """
        Workflow node execute failed
        """
        self._queue_manager.publish(
            QueueNodeFailedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                outputs=outputs,
                process_data=process_data,
                error=error
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        """
        Publish text chunk
        """
        self._queue_manager.publish(
            QueueTextChunkEvent(
                text=text,
                metadata={
                    "node_id": node_id,
                    **metadata
                }
            ), PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_started(self,
                                      node_id: str,
                                      node_type: NodeType,
                                      node_run_index: int = 1,
                                      node_data: Optional[BaseNodeData] = None,
                                      inputs: dict = None,
                                      predecessor_node_id: Optional[str] = None,
                                      metadata: Optional[dict] = None) -> None:
        """
        Publish iteration started
        """
        self._queue_manager.publish(
            QueueIterationStartEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                node_data=node_data,
                inputs=inputs,
                predecessor_node_id=predecessor_node_id,
                metadata=metadata
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_next(self, node_id: str,
                                   node_type: NodeType,
                                   index: int,
                                   node_run_index: int,
                                   output: Optional[Any]) -> None:
        """
        Publish iteration next
        """
        self._queue_manager._publish(
            QueueIterationNextEvent(
                node_id=node_id,
                node_type=node_type,
                index=index,
                node_run_index=node_run_index,
                output=output
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_completed(self, node_id: str,
                                        node_type: NodeType,
                                        node_run_index: int,
                                        outputs: dict) -> None:
        """
        Publish iteration completed
        """
        self._queue_manager._publish(
            QueueIterationCompletedEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                outputs=outputs
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_event(self, event: AppQueueEvent) -> None:
        """
        Publish event
        """
        self._queue_manager.publish(
            event,
            PublishFrom.APPLICATION_MANAGER
        )
@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
    def convert(cls, response: Union[
        AppBlockingResponse,
        Generator[AppStreamResponse, Any, None]
    ], invoke_from: InvokeFrom):
    ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
        if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_full_response(response)
@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union

from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -347,7 +347,7 @@ class AppRunner:
        self, app_id: str,
        tenant_id: str,
        app_generate_entity: AppGenerateEntity,
        inputs: dict,
        inputs: Mapping[str, Any],
        query: str,
        message_id: str,
    ) -> tuple[bool, dict, str]:
@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
        args: dict,
        invoke_from: InvokeFrom,
        stream: Literal[True] = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
    ) -> Generator[str, None, None]: ...

    @overload
@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
        args: dict,
        invoke_from: InvokeFrom,
        stream: Literal[False] = False,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
    ) -> dict: ...

    def generate(
        self, app_model: App,
        self,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None
    ):
        """
        Generate App response.
@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param invoke_from: invoke from source
        :param stream: is stream
        :param call_depth: call depth
        :param workflow_thread_pool_id: workflow thread pool id
        """
        inputs = args['inputs']

@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
            application_generate_entity=application_generate_entity,
            invoke_from=invoke_from,
            stream=stream,
            workflow_thread_pool_id=workflow_thread_pool_id
        )

    def _generate(
        self, app_model: App,
        self, *,
        app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        application_generate_entity: WorkflowAppGenerateEntity,
        invoke_from: InvokeFrom,
        stream: bool = True,
    ) -> Union[dict, Generator[str, None, None]]:
        workflow_thread_pool_id: Optional[str] = None
    ) -> dict[str, Any] | Generator[str, None, None]:
        """
        Generate App response.

@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param stream: is stream
        :param workflow_thread_pool_id: workflow thread pool id
        """
        # init queue manager
        queue_manager = WorkflowAppQueueManager(
@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):

        # new thread
        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'flask_app': current_app._get_current_object(),  # type: ignore
            'application_generate_entity': application_generate_entity,
            'queue_manager': queue_manager,
            'context': contextvars.copy_context()
            'context': contextvars.copy_context(),
            'workflow_thread_pool_id': workflow_thread_pool_id
        })

        worker_thread.start()
@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
            node_id: str,
            user: Account,
            args: dict,
            stream: bool = True):
            stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
        """
        Generate App response.

@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        if args.get('inputs') is None:
            raise ValueError('inputs is required')

        extras = {
            "auto_generate_conversation_name": False
        }

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
            user_id=user.id,
            stream=stream,
            invoke_from=InvokeFrom.DEBUGGER,
            extras=extras,
            extras={
                "auto_generate_conversation_name": False
            },
            single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id,
                inputs=args['inputs']
@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
    def _generate_worker(self, flask_app: Flask,
                         application_generate_entity: WorkflowAppGenerateEntity,
                         queue_manager: AppQueueManager,
                         context: contextvars.Context) -> None:
                         context: contextvars.Context,
                         workflow_thread_pool_id: Optional[str] = None) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """
        for var, val in context.items():
@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
        with flask_app.app_context():
            try:
                # workflow app
                runner = WorkflowAppRunner()
                if application_generate_entity.single_iteration_run:
                    single_iteration_run = application_generate_entity.single_iteration_run
                    runner.single_iteration_run(
                        app_id=application_generate_entity.app_config.app_id,
                        workflow_id=application_generate_entity.app_config.workflow_id,
                        queue_manager=queue_manager,
                        inputs=single_iteration_run.inputs,
                        node_id=single_iteration_run.node_id,
                        user_id=application_generate_entity.user_id
                    )
                else:
                    runner.run(
                        application_generate_entity=application_generate_entity,
                        queue_manager=queue_manager
                    )
                runner = WorkflowAppRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id
                )

                runner.run()
            except GenerateTaskStoppedException:
                pass
            except InvokeAuthorizationError:
@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except (ValueError, InvokeError) as e:
                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.remove()
                db.session.close()

    def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
                         workflow: Workflow,
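Note on the hunks above: the worker no longer branches on single_iteration_run itself; that decision moves into the runner. A condensed sketch of the new calling convention, using only names that appear in this diff (the surrounding Flask and queue setup is elided, so this is illustrative rather than runnable in isolation):

    # inside _generate_worker, after this change
    runner = WorkflowAppRunner(
        application_generate_entity=application_generate_entity,  # carries single_iteration_run, inputs, files
        queue_manager=queue_manager,                              # events published here feed the task pipeline
        workflow_thread_pool_id=workflow_thread_pool_id,          # lets nested runs share one thread pool
    )
    runner.run()  # single-iteration vs. full-graph branching now happens inside run()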
@ -4,46 +4,61 @@ from typing import Optional, cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType

logger = logging.getLogger(__name__)


class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
    """
    Workflow Application Runner
    """

    def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
    def __init__(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def run(self) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = self.application_generate_entity.app_config
        app_config = cast(WorkflowAppConfig, app_config)

        user_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = application_generate_entity.user_id
            user_id = self.application_generate_entity.user_id

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
@ -53,80 +68,64 @@ class WorkflowAppRunner:
        if not workflow:
            raise ValueError('Workflow not initialized')

        inputs = application_generate_entity.inputs
        files = application_generate_entity.files

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = [
            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
        ]

        workflow_callbacks: list[WorkflowCallback] = []
        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
            workflow_callbacks.append(WorkflowLoggingCallback())

        # Create a variable pool.
        system_inputs = {
            SystemVariableKey.FILES: files,
            SystemVariableKey.USER_ID: user_id,
        }
        variable_pool = VariablePool(
            system_variables=system_inputs,
            user_inputs=inputs,
            environment_variables=workflow.environment_variables,
            conversation_variables=[],
        )
        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
            }

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
            )

            # init graph
            graph = self._init_graph(graph_config=workflow.graph_dict)

        # RUN WORKFLOW
        workflow_engine_manager = WorkflowEngineManager()
        workflow_engine_manager.run_workflow(
            workflow=workflow,
            user_id=application_generate_entity.user_id,
            user_from=UserFrom.ACCOUNT
            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
            else UserFrom.END_USER,
            invoke_from=application_generate_entity.invoke_from,
            callbacks=workflow_callbacks,
            call_depth=application_generate_entity.call_depth,
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id
        )

    def single_iteration_run(
        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
    ) -> None:
        """
        Single iteration run
        """
        app_record = db.session.query(App).filter(App.id == app_id).first()
        if not app_record:
            raise ValueError('App not found')

        if not app_record.workflow_id:
            raise ValueError('Workflow not initialized')

        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
        if not workflow:
            raise ValueError('Workflow not initialized')

        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]

        workflow_engine_manager = WorkflowEngineManager()
        workflow_engine_manager.single_step_run_iteration_workflow_node(
            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
        generator = workflow_entry.run(
            callbacks=workflow_callbacks
        )

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow
        for event in generator:
            self._handle_event(workflow_entry, event)
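The rewritten run() above reduces to two construction paths feeding a single execution path. A minimal sketch of that shape, with names taken from the diff and argument lists elided (an outline under those assumptions, not a drop-in implementation):

    # outline of WorkflowAppRunner.run() after this change
    if self.application_generate_entity.single_iteration_run:
        # sub-graph rooted at the iteration node, plus a matching variable pool
        graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(...)
    else:
        # full graph; pool seeded with sys.files / sys.user_id
        graph = self._init_graph(graph_config=workflow.graph_dict)

    workflow_entry = WorkflowEntry(..., graph=graph, variable_pool=variable_pool,
                                   thread_pool_id=self.workflow_thread_pool_id)
    for event in workflow_entry.run(callbacks=workflow_callbacks):
        self._handle_event(workflow_entry, event)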
@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueMessageReplaceEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueParallelBranchRunFailedEvent,
    QueueParallelBranchRunStartedEvent,
    QueueParallelBranchRunSucceededEvent,
    QueuePingEvent,
    QueueStopEvent,
    QueueTextChunkEvent,
@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
    MessageAudioStreamResponse,
    StreamResponse,
    TextChunkStreamResponse,
    TextReplaceStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppStreamResponse,
    WorkflowFinishStreamResponse,
    WorkflowStreamGenerateNodes,
    WorkflowStartStreamResponse,
    WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -52,8 +52,8 @@ from models.workflow import (
    Workflow,
    WorkflowAppLog,
    WorkflowAppLogCreatedFrom,
    WorkflowNodeExecution,
    WorkflowRun,
    WorkflowRunStatus,
)

logger = logging.getLogger(__name__)
@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
    _task_state: WorkflowTaskState
    _application_generate_entity: WorkflowAppGenerateEntity
    _workflow_system_variables: dict[SystemVariableKey, Any]
    _iteration_nested_relations: dict[str, list[str]]

    def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
                 workflow: Workflow,
@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            SystemVariableKey.USER_ID: user_id
        }

        self._task_state = WorkflowTaskState(
            iteration_nested_node_ids=[]
        )
        self._stream_generate_nodes = self._get_stream_generate_nodes()
        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
        self._task_state = WorkflowTaskState()

    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, WorkflowFinishStreamResponse):
                workflow_run = db.session.query(WorkflowRun).filter(
                    WorkflowRun.id == self._task_state.workflow_run_id).first()

                response = WorkflowAppBlockingResponse(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run_id=workflow_run.id,
                    workflow_run_id=stream_response.data.id,
                    data=WorkflowAppBlockingResponse.Data(
                        id=workflow_run.id,
                        workflow_id=workflow_run.workflow_id,
                        status=workflow_run.status,
                        outputs=workflow_run.outputs_dict,
                        error=workflow_run.error,
                        elapsed_time=workflow_run.elapsed_time,
                        total_tokens=workflow_run.total_tokens,
                        total_steps=workflow_run.total_steps,
                        created_at=int(workflow_run.created_at.timestamp()),
                        finished_at=int(workflow_run.finished_at.timestamp())
                        id=stream_response.data.id,
                        workflow_id=stream_response.data.workflow_id,
                        status=stream_response.data.status,
                        outputs=stream_response.data.outputs,
                        error=stream_response.data.error,
                        elapsed_time=stream_response.data.elapsed_time,
                        total_tokens=stream_response.data.total_tokens,
                        total_steps=stream_response.data.total_steps,
                        created_at=int(stream_response.data.created_at),
                        finished_at=int(stream_response.data.finished_at)
                    )
                )

@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        To stream response.
        :return:
        """
        workflow_run_id = None
        for stream_response in generator:
            if isinstance(stream_response, WorkflowStartStreamResponse):
                workflow_run_id = stream_response.workflow_run_id

            yield WorkflowAppStreamResponse(
                workflow_run_id=self._task_state.workflow_run_id,
                workflow_run_id=workflow_run_id,
                stream_response=stream_response
            )

@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

        publisher = None
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                'text_to_speech'].get('autoPlay') == 'enabled':
            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            start_listener_time = time.time()
            while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
                try:
                    if not publisher:
                    if not tts_publisher:
                        break
                    audio_trunk = publisher.checkAndGetAudio()
                    audio_trunk = tts_publisher.checkAndGetAudio()
                    if audio_trunk is None:
                        # release cpu
                        # sleep 20 ms (40ms => 1280 byte audio file, 20ms => 640 byte audio file)
@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

    def _process_stream_response(
        self,
        publisher: AppGeneratorTTSPublisher,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            if publisher:
                publisher.publish(message=message)
            event = message.event
        graph_runtime_state = None
        workflow_run = None

            if isinstance(event, QueueErrorEvent):
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            elif isinstance(event, QueueErrorEvent):
                err = self._handle_error(event)
                yield self._error_to_stream_response(err)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                workflow_run = self._handle_workflow_start()
                # override graph runtime state
                graph_runtime_state = event.graph_runtime_state

                # init workflow run
                workflow_run = self._handle_workflow_run_start()
                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                )
            elif isinstance(event, QueueNodeStartedEvent):
                workflow_node_execution = self._handle_node_start(event)
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                # search stream_generate_routes if node id is answer start at node
                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
                    self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
                workflow_node_execution = self._handle_node_execution_start(
                    workflow_run=workflow_run,
                    event=event
                )

                # generate stream outputs when node started
                yield from self._generate_stream_outputs_when_node_started()

                yield self._workflow_node_start_to_stream_response(
                response = self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                )
            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                workflow_node_execution = self._handle_node_finished(event)

                yield self._workflow_node_finish_to_stream_response(
                if response:
                    yield response
            elif isinstance(event, QueueNodeSucceededEvent):
                workflow_node_execution = self._handle_workflow_node_execution_success(event)

                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                )

                if isinstance(event, QueueNodeFailedEvent):
                    yield from self._handle_iteration_exception(
                        task_id=self._application_generate_entity.task_id,
                        error=f'Child node failed: {event.error}'
                    )
            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
                if isinstance(event, QueueIterationNextEvent):
                    # clear ran node execution infos of current iteration
                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
                    if iteration_relations:
                        for node_id in iteration_relations:
                            self._task_state.ran_node_execution_infos.pop(node_id, None)
                if response:
                    yield response
            elif isinstance(event, QueueNodeFailedEvent):
                workflow_node_execution = self._handle_workflow_node_execution_failed(event)

                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
                self._handle_iteration_operation(event)
            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
                workflow_run = self._handle_workflow_finished(
                    event, trace_manager=trace_manager
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                )

                if response:
                    yield response
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run,
                    event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                if not graph_runtime_state:
                    raise Exception('Graph runtime state not initialized.')

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
                    conversation_id=None,
                    trace_manager=trace_manager,
                )

                # save workflow app log
                self._save_workflow_app_log(workflow_run)

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                )
            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                if not workflow_run:
                    raise Exception('Workflow run not initialized.')

                if not graph_runtime_state:
                    raise Exception('Graph runtime state not initialized.')

                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
                    conversation_id=None,
                    trace_manager=trace_manager,
                )

                # save workflow app log
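The event loop above repeatedly guards later events on state produced by earlier ones: workflow_run comes from QueueWorkflowStartedEvent, graph_runtime_state likewise. A small self-contained illustration of that dispatch-with-guards pattern; the event classes here are stand-ins, not the real Dify queue events:

    from dataclasses import dataclass

    @dataclass
    class Started:
        run_id: str

    @dataclass
    class NodeDone:
        node_id: str

    def process(events):
        run_id = None  # populated by the start event, required by all later events
        for event in events:
            if isinstance(event, Started):
                run_id = event.run_id
            elif isinstance(event, NodeDone):
                if run_id is None:
                    raise Exception('Workflow run not initialized.')
                yield f'{run_id}:{event.node_id}'

    print(list(process([Started('r1'), NodeDone('n1')])))  # ['r1:n1']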
@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                if delta_text is None:
                    continue

                if not self._is_stream_out_support(
                        event=event
                ):
                    continue
                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(message=queue_message)

                self._task_state.answer += delta_text
                yield self._text_chunk_to_stream_response(delta_text)
            elif isinstance(event, QueueMessageReplaceEvent):
                yield self._text_replace_to_stream_response(event.text)
            elif isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            else:
                continue

        if publisher:
            publisher.publish(None)
        if tts_publisher:
            tts_publisher.publish(None)

    def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            # not save log for debugging
            return

        workflow_app_log = WorkflowAppLog(
            tenant_id=workflow_run.tenant_id,
            app_id=workflow_run.app_id,
            workflow_id=workflow_run.workflow_id,
            workflow_run_id=workflow_run.id,
            created_from=created_from.value,
            created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
            created_by=self._user.id,
        )
        workflow_app_log = WorkflowAppLog()
        workflow_app_log.tenant_id = workflow_run.tenant_id
        workflow_app_log.app_id = workflow_run.app_id
        workflow_app_log.workflow_id = workflow_run.workflow_id
        workflow_app_log.workflow_run_id = workflow_run.id
        workflow_app_log.created_from = created_from.value
        workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
        workflow_app_log.created_by = self._user.id

        db.session.add(workflow_app_log)
        db.session.commit()
        db.session.close()
@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        )

        return response

    def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
        """
        Text replace to stream response.
        :param text: text
        :return:
        """
        return TextReplaceStreamResponse(
            task_id=self._application_generate_entity.task_id,
            text=TextReplaceStreamResponse.Data(text=text)
        )

    def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
        """
        Get stream generate nodes.
        :return:
        """
        # find all answer nodes
        graph = self._workflow.graph_dict
        end_node_configs = [
            node for node in graph['nodes']
            if node.get('data', {}).get('type') == NodeType.END.value
        ]

        # parse stream output node value selectors of end nodes
        stream_generate_routes = {}
        for node_config in end_node_configs:
            # get generate route for stream output
            end_node_id = node_config['id']
            generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
            start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
            if not start_node_ids:
                continue

            for start_node_id in start_node_ids:
                stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
                    end_node_id=end_node_id,
                    stream_node_ids=generate_nodes
                )

        return stream_generate_routes

    def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
            -> list[str]:
        """
        Get end start at node id.
        :param graph: graph
        :param target_node_id: target node ID
        :return:
        """
        nodes = graph.get('nodes')
        edges = graph.get('edges')

        # fetch all ingoing edges from source node
        ingoing_edges = []
        for edge in edges:
            if edge.get('target') == target_node_id:
                ingoing_edges.append(edge)

        if not ingoing_edges:
            return []

        start_node_ids = []
        for ingoing_edge in ingoing_edges:
            source_node_id = ingoing_edge.get('source')
            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
            if not source_node:
                continue

            node_type = source_node.get('data', {}).get('type')
            node_iteration_id = source_node.get('data', {}).get('iteration_id')
            iteration_start_node_id = None
            if node_iteration_id:
                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')

            if node_type in [
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value
            ]:
                start_node_id = target_node_id
                start_node_ids.append(start_node_id)
            elif node_type == NodeType.START.value or \
                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                start_node_id = source_node_id
                start_node_ids.append(start_node_id)
            else:
                sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
                if sub_start_node_ids:
                    start_node_ids.extend(sub_start_node_ids)

        return start_node_ids

    def _generate_stream_outputs_when_node_started(self) -> Generator:
        """
        Generate stream outputs.
        :return:
        """
        if self._task_state.current_stream_generate_state:
            stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids

            for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
                if node_id not in stream_node_ids:
                    continue

                node_execution_info = self._task_state.ran_node_execution_infos[node_id]

                # get chunk node execution
                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
                    WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()

                if not route_chunk_node_execution:
                    continue

                outputs = route_chunk_node_execution.outputs_dict

                if not outputs:
                    continue

                # get value from outputs
                text = outputs.get('text')

                if text:
                    self._task_state.answer += text
                    yield self._text_chunk_to_stream_response(text)

            db.session.close()

    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
        """
        Is stream out support
        :param event: queue text chunk event
        :return:
        """
        if not event.metadata:
            return False

        if 'node_id' not in event.metadata:
            return False

        node_id = event.metadata.get('node_id')
        node_type = event.metadata.get('node_type')
        stream_output_value_selector = event.metadata.get('value_selector')
        if not stream_output_value_selector:
            return False

        if not self._task_state.current_stream_generate_state:
            return False

        if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
            return False

        if node_type != NodeType.LLM:
            # only LLM nodes support chunk stream output
            return False

        return True

    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
        """
        Get iteration nested relations.
        :param graph: graph
        :return:
        """
        nodes = graph.get('nodes')

        iteration_ids = [node.get('id') for node in nodes
                         if node.get('data', {}).get('type') in [
                             NodeType.ITERATION.value,
                             NodeType.LOOP.value,
                         ]]

        return {
            iteration_id: [
                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
            ] for iteration_id in iteration_ids
        }
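The two WorkflowEventTriggerCallback modules deleted in this commit (one below, one earlier in the diff) each forwarded engine callbacks onto the queue one hook at a time. Their role is taken over by WorkflowBasedAppRunner._handle_event in the new workflow_app_runner.py further down, which consumes GraphEngineEvent objects and republishes queue events from a single dispatcher. A self-contained illustration of that adapter shape; the two event classes are stand-ins, not the real Dify types:

    class GraphRunStartedEvent:       # stand-in for the engine-side event
        pass

    class QueueWorkflowStartedEvent:  # stand-in for the queue-side event
        pass

    def handle_event(event, publish):
        # one dispatcher replaces a callback object with one method per hook
        if isinstance(event, GraphRunStartedEvent):
            publish(QueueWorkflowStartedEvent())

    handle_event(GraphRunStartedEvent(), lambda e: print(type(e).__name__))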
@ -1,200 +0,0 @@
from typing import Any, Optional

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow


class WorkflowEventTriggerCallback(WorkflowCallback):

    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
        self._queue_manager = queue_manager

    def on_workflow_run_started(self) -> None:
        """
        Workflow run started
        """
        self._queue_manager.publish(
            QueueWorkflowStartedEvent(),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_succeeded(self) -> None:
        """
        Workflow run succeeded
        """
        self._queue_manager.publish(
            QueueWorkflowSucceededEvent(),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_failed(self, error: str) -> None:
        """
        Workflow run failed
        """
        self._queue_manager.publish(
            QueueWorkflowFailedEvent(
                error=error
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_started(self, node_id: str,
                                         node_type: NodeType,
                                         node_data: BaseNodeData,
                                         node_run_index: int = 1,
                                         predecessor_node_id: Optional[str] = None) -> None:
        """
        Workflow node execute started
        """
        self._queue_manager.publish(
            QueueNodeStartedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                node_run_index=node_run_index,
                predecessor_node_id=predecessor_node_id
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_succeeded(self, node_id: str,
                                           node_type: NodeType,
                                           node_data: BaseNodeData,
                                           inputs: Optional[dict] = None,
                                           process_data: Optional[dict] = None,
                                           outputs: Optional[dict] = None,
                                           execution_metadata: Optional[dict] = None) -> None:
        """
        Workflow node execute succeeded
        """
        self._queue_manager.publish(
            QueueNodeSucceededEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                execution_metadata=execution_metadata
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_failed(self, node_id: str,
                                        node_type: NodeType,
                                        node_data: BaseNodeData,
                                        error: str,
                                        inputs: Optional[dict] = None,
                                        outputs: Optional[dict] = None,
                                        process_data: Optional[dict] = None) -> None:
        """
        Workflow node execute failed
        """
        self._queue_manager.publish(
            QueueNodeFailedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                outputs=outputs,
                process_data=process_data,
                error=error
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        """
        Publish text chunk
        """
        self._queue_manager.publish(
            QueueTextChunkEvent(
                text=text,
                metadata={
                    "node_id": node_id,
                    **metadata
                }
            ), PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_started(self,
                                      node_id: str,
                                      node_type: NodeType,
                                      node_run_index: int = 1,
                                      node_data: Optional[BaseNodeData] = None,
                                      inputs: dict = None,
                                      predecessor_node_id: Optional[str] = None,
                                      metadata: Optional[dict] = None) -> None:
        """
        Publish iteration started
        """
        self._queue_manager.publish(
            QueueIterationStartEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                node_data=node_data,
                inputs=inputs,
                predecessor_node_id=predecessor_node_id,
                metadata=metadata
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_next(self, node_id: str,
                                   node_type: NodeType,
                                   index: int,
                                   node_run_index: int,
                                   output: Optional[Any]) -> None:
        """
        Publish iteration next
        """
        self._queue_manager.publish(
            QueueIterationNextEvent(
                node_id=node_id,
                node_type=node_type,
                index=index,
                node_run_index=node_run_index,
                output=output
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_completed(self, node_id: str,
                                        node_type: NodeType,
                                        node_run_index: int,
                                        outputs: dict) -> None:
        """
        Publish iteration completed
        """
        self._queue_manager.publish(
            QueueIterationCompletedEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                outputs=outputs
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_event(self, event: AppQueueEvent) -> None:
        """
        Publish event
        """
        pass
api/core/app/apps/workflow_app_runner.py (new file, 379 lines)
@ -0,0 +1,379 @@
|
|||
from collections.abc import Mapping
from typing import Any, Optional, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueParallelBranchRunFailedEvent,
    QueueParallelBranchRunStartedEvent,
    QueueParallelBranchRunSucceededEvent,
    QueueRetrieverResourcesEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,
    IterationRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunRetrieverResourceEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
    ParallelBranchRunFailedEvent,
    ParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow


class WorkflowBasedAppRunner(AppRunner):
    def __init__(self, queue_manager: AppQueueManager):
        self.queue_manager = queue_manager

    def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
        """
        Init graph
        """
        if 'nodes' not in graph_config or 'edges' not in graph_config:
            raise ValueError('nodes or edges not found in workflow graph')

        if not isinstance(graph_config.get('nodes'), list):
            raise ValueError('nodes in workflow graph must be a list')

        if not isinstance(graph_config.get('edges'), list):
            raise ValueError('edges in workflow graph must be a list')

        # init graph
        graph = Graph.init(
            graph_config=graph_config
        )

        if not graph:
            raise ValueError('graph not found in workflow')

        return graph
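
Note: a minimal sketch of how `_init_graph` might be driven. The two-node graph_config below is hypothetical and only illustrates the expected 'nodes'/'edges' shape; it is not taken from this commit.

    # assumes an AppQueueManager instance is already at hand as `queue_manager`
    runner = WorkflowBasedAppRunner(queue_manager=queue_manager)
    graph = runner._init_graph(graph_config={
        'nodes': [
            {'id': 'start', 'data': {'type': 'start', 'title': 'Start'}},
            {'id': 'llm', 'data': {'type': 'llm', 'title': 'LLM'}},
        ],
        'edges': [
            {'source': 'start', 'target': 'llm'},
        ],
    })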

    def _get_graph_and_variable_pool_of_single_iteration(
        self,
        workflow: Workflow,
        node_id: str,
        user_inputs: dict,
    ) -> tuple[Graph, VariablePool]:
        """
        Get graph and variable pool of a single iteration
        """
        # fetch workflow graph
        graph_config = workflow.graph_dict
        if not graph_config:
            raise ValueError('workflow graph not found')

        graph_config = cast(dict[str, Any], graph_config)

        if 'nodes' not in graph_config or 'edges' not in graph_config:
            raise ValueError('nodes or edges not found in workflow graph')

        if not isinstance(graph_config.get('nodes'), list):
            raise ValueError('nodes in workflow graph must be a list')

        if not isinstance(graph_config.get('edges'), list):
            raise ValueError('edges in workflow graph must be a list')

        # filter nodes only in iteration
        node_configs = [
            node for node in graph_config.get('nodes', [])
            if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
        ]

        graph_config['nodes'] = node_configs

        node_ids = [node.get('id') for node in node_configs]

        # filter edges only in iteration
        edge_configs = [
            edge for edge in graph_config.get('edges', [])
            if (edge.get('source') is None or edge.get('source') in node_ids)
            and (edge.get('target') is None or edge.get('target') in node_ids)
        ]

        graph_config['edges'] = edge_configs

        # init graph
        graph = Graph.init(
            graph_config=graph_config,
            root_node_id=node_id
        )

        if not graph:
            raise ValueError('graph not found in workflow')

        # fetch node config from node id
        iteration_node_config = None
        for node in node_configs:
            if node.get('id') == node_id:
                iteration_node_config = node
                break

        if not iteration_node_config:
            raise ValueError('iteration node id not found in workflow graph')

        # Get node class
        node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
        node_cls = node_classes.get(node_type)
        node_cls = cast(type[BaseNode], node_cls)

        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={},
            environment_variables=workflow.environment_variables,
        )

        try:
            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                graph_config=workflow.graph_dict,
                config=iteration_node_config
            )
        except NotImplementedError:
            variable_mapping = {}

        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id=workflow.tenant_id,
            node_type=node_type,
            node_data=IterationNodeData(**iteration_node_config.get('data', {}))
        )

        return graph, variable_pool
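
The filtering above keeps only the iteration node itself plus nodes whose data.iteration_id points at it, then drops every edge that touches anything outside that set. A standalone sketch of the same idea (ids are made up for illustration):

    graph_config = {
        'nodes': [
            {'id': 'iter-1', 'data': {'type': 'iteration', 'title': 'Loop'}},
            {'id': 'llm-in-loop', 'data': {'type': 'llm', 'iteration_id': 'iter-1'}},
            {'id': 'outside', 'data': {'type': 'code'}},
        ],
        'edges': [
            {'source': 'iter-1', 'target': 'llm-in-loop'},
            {'source': 'outside', 'target': 'iter-1'},
        ],
    }
    node_id = 'iter-1'
    nodes = [n for n in graph_config['nodes']
             if n.get('id') == node_id or n.get('data', {}).get('iteration_id', '') == node_id]
    node_ids = [n.get('id') for n in nodes]
    edges = [e for e in graph_config['edges']
             if (e.get('source') is None or e.get('source') in node_ids)
             and (e.get('target') is None or e.get('target') in node_ids)]
    # nodes keeps 'iter-1' and 'llm-in-loop'; the 'outside' node and its edge are dropped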

    def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
        """
        Handle event
        :param workflow_entry: workflow entry
        :param event: event
        """
        if isinstance(event, GraphRunStartedEvent):
            self._publish_event(
                QueueWorkflowStartedEvent(
                    graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
                )
            )
        elif isinstance(event, GraphRunSucceededEvent):
            self._publish_event(
                QueueWorkflowSucceededEvent(outputs=event.outputs)
            )
        elif isinstance(event, GraphRunFailedEvent):
            self._publish_event(
                QueueWorkflowFailedEvent(error=event.error)
            )
        elif isinstance(event, NodeRunStartedEvent):
            self._publish_event(
                QueueNodeStartedEvent(
                    node_execution_id=event.id,
                    node_id=event.node_id,
                    node_type=event.node_type,
                    node_data=event.node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.route_node_state.start_at,
                    node_run_index=event.route_node_state.index,
                    predecessor_node_id=event.predecessor_node_id,
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, NodeRunSucceededEvent):
            self._publish_event(
                QueueNodeSucceededEvent(
                    node_execution_id=event.id,
                    node_id=event.node_id,
                    node_type=event.node_type,
                    node_data=event.node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.route_node_state.start_at,
                    inputs=event.route_node_state.node_run_result.inputs
                    if event.route_node_state.node_run_result else {},
                    process_data=event.route_node_state.node_run_result.process_data
                    if event.route_node_state.node_run_result else {},
                    outputs=event.route_node_state.node_run_result.outputs
                    if event.route_node_state.node_run_result else {},
                    execution_metadata=event.route_node_state.node_run_result.metadata
                    if event.route_node_state.node_run_result else {},
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, NodeRunFailedEvent):
            self._publish_event(
                QueueNodeFailedEvent(
                    node_execution_id=event.id,
                    node_id=event.node_id,
                    node_type=event.node_type,
                    node_data=event.node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.route_node_state.start_at,
                    inputs=event.route_node_state.node_run_result.inputs
                    if event.route_node_state.node_run_result else {},
                    process_data=event.route_node_state.node_run_result.process_data
                    if event.route_node_state.node_run_result else {},
                    outputs=event.route_node_state.node_run_result.outputs
                    if event.route_node_state.node_run_result else {},
                    error=event.route_node_state.node_run_result.error
                    if event.route_node_state.node_run_result
                    and event.route_node_state.node_run_result.error
                    else "Unknown error",
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, NodeRunStreamChunkEvent):
            self._publish_event(
                QueueTextChunkEvent(
                    text=event.chunk_content,
                    from_variable_selector=event.from_variable_selector,
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, NodeRunRetrieverResourceEvent):
            self._publish_event(
                QueueRetrieverResourcesEvent(
                    retriever_resources=event.retriever_resources,
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self._publish_event(
                QueueParallelBranchRunStartedEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, ParallelBranchRunSucceededEvent):
            self._publish_event(
                QueueParallelBranchRunSucceededEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, ParallelBranchRunFailedEvent):
            self._publish_event(
                QueueParallelBranchRunFailedEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id,
                    error=event.error
                )
            )
        elif isinstance(event, IterationRunStartedEvent):
            self._publish_event(
                QueueIterationStartEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.start_at,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    inputs=event.inputs,
                    predecessor_node_id=event.predecessor_node_id,
                    metadata=event.metadata
                )
            )
        elif isinstance(event, IterationRunNextEvent):
            self._publish_event(
                QueueIterationNextEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    index=event.index,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    output=event.pre_iteration_output,
                )
            )
        elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
            self._publish_event(
                QueueIterationCompletedEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.start_at,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    inputs=event.inputs,
                    outputs=event.outputs,
                    metadata=event.metadata,
                    steps=event.steps,
                    error=event.error if isinstance(event, IterationRunFailedEvent) else None
                )
            )
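
The ladder above is a one-to-one translation from engine events to queue events; only the node and iteration cases pull extra state from route_node_state or graph_runtime_state. For the stateless cases the same mapping could be expressed as a dispatch table; this is a sketch of the idea, not how the commit implements it:

    SIMPLE_EVENT_MAP = {
        GraphRunSucceededEvent: lambda e: QueueWorkflowSucceededEvent(outputs=e.outputs),
        GraphRunFailedEvent: lambda e: QueueWorkflowFailedEvent(error=e.error),
    }

    def translate_simple(event: GraphEngineEvent):
        # returns None for stateful events, which still need the isinstance ladder
        factory = SIMPLE_EVENT_MAP.get(type(event))
        return factory(event) if factory else None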

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow

    def _publish_event(self, event: AppQueueEvent) -> None:
        self.queue_manager.publish(
            event,
            PublishFrom.APPLICATION_MANAGER
        )
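
Every translated event funnels through `_publish_event`, always tagged PublishFrom.APPLICATION_MANAGER. The intended flow is roughly the following (assuming WorkflowEntry.run() yields GraphEngineEvent objects, as the `_handle_event` signature suggests):

    for engine_event in workflow_entry.run():
        runner._handle_event(workflow_entry, engine_event)  # re-published onto the app queue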

@ -1,10 +1,24 @@
from typing import Optional

-from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ParallelBranchRunFailedEvent,
+    ParallelBranchRunStartedEvent,
+    ParallelBranchRunSucceededEvent,
+)

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",

@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
    def __init__(self) -> None:
        self.current_node_id = None

-    def on_workflow_run_started(self) -> None:
-        """
-        Workflow run started
-        """
-        self.print_text("\n[on_workflow_run_started]", color='pink')
+    def on_event(
+        self,
+        event: GraphEngineEvent
+    ) -> None:
+        if isinstance(event, GraphRunStartedEvent):
+            self.print_text("\n[GraphRunStartedEvent]", color='pink')
+        elif isinstance(event, GraphRunSucceededEvent):
+            self.print_text("\n[GraphRunSucceededEvent]", color='green')
+        elif isinstance(event, GraphRunFailedEvent):
+            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
+        elif isinstance(event, NodeRunStartedEvent):
+            self.on_workflow_node_execute_started(
+                event=event
+            )
+        elif isinstance(event, NodeRunSucceededEvent):
+            self.on_workflow_node_execute_succeeded(
+                event=event
+            )
+        elif isinstance(event, NodeRunFailedEvent):
+            self.on_workflow_node_execute_failed(
+                event=event
+            )
+        elif isinstance(event, NodeRunStreamChunkEvent):
+            self.on_node_text_chunk(
+                event=event
+            )
+        elif isinstance(event, ParallelBranchRunStartedEvent):
+            self.on_workflow_parallel_started(
+                event=event
+            )
+        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
+            self.on_workflow_parallel_completed(
+                event=event
+            )
+        elif isinstance(event, IterationRunStartedEvent):
+            self.on_workflow_iteration_started(
+                event=event
+            )
+        elif isinstance(event, IterationRunNextEvent):
+            self.on_workflow_iteration_next(
+                event=event
+            )
+        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
+            self.on_workflow_iteration_completed(
+                event=event
+            )
+        else:
+            self.print_text(f"\n[{event.__class__.__name__}]", color='blue')

-    def on_workflow_run_succeeded(self) -> None:
-        """
-        Workflow run succeeded
-        """
-        self.print_text("\n[on_workflow_run_succeeded]", color='green')
-
-    def on_workflow_run_failed(self, error: str) -> None:
-        """
-        Workflow run failed
-        """
-        self.print_text("\n[on_workflow_run_failed]", color='red')
-
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
+    def on_workflow_node_execute_started(
+        self,
+        event: NodeRunStartedEvent
+    ) -> None:
        """
        Workflow node execute started
        """
-        self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
-        self.print_text(f"Node ID: {node_id}", color='yellow')
-        self.print_text(f"Type: {node_type.value}", color='yellow')
-        self.print_text(f"Index: {node_run_index}", color='yellow')
-        if predecessor_node_id:
-            self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
+        self.print_text("\n[NodeRunStartedEvent]", color='yellow')
+        self.print_text(f"Node ID: {event.node_id}", color='yellow')
+        self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
+        self.print_text(f"Type: {event.node_type.value}", color='yellow')

-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
+    def on_workflow_node_execute_succeeded(
+        self,
+        event: NodeRunSucceededEvent
+    ) -> None:
        """
        Workflow node execute succeeded
        """
-        self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
-        self.print_text(f"Node ID: {node_id}", color='green')
-        self.print_text(f"Type: {node_type.value}", color='green')
-        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
-        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
-        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
-        self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
-                        color='green')
+        route_node_state = event.route_node_state
+
+        self.print_text("\n[NodeRunSucceededEvent]", color='green')
+        self.print_text(f"Node ID: {event.node_id}", color='green')
+        self.print_text(f"Node Title: {event.node_data.title}", color='green')
+        self.print_text(f"Type: {event.node_type.value}", color='green')
+
+        if route_node_state.node_run_result:
+            node_run_result = route_node_state.node_run_result
+            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
+                            color='green')
+            self.print_text(
+                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
+                color='green')
+            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                            color='green')
+            self.print_text(
+                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
+                color='green')

-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
+    def on_workflow_node_execute_failed(
+        self,
+        event: NodeRunFailedEvent
+    ) -> None:
        """
        Workflow node execute failed
        """
-        self.print_text("\n[on_workflow_node_execute_failed]", color='red')
-        self.print_text(f"Node ID: {node_id}", color='red')
-        self.print_text(f"Type: {node_type.value}", color='red')
-        self.print_text(f"Error: {error}", color='red')
-        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
-        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
-        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
+        route_node_state = event.route_node_state
+
+        self.print_text("\n[NodeRunFailedEvent]", color='red')
+        self.print_text(f"Node ID: {event.node_id}", color='red')
+        self.print_text(f"Node Title: {event.node_data.title}", color='red')
+        self.print_text(f"Type: {event.node_type.value}", color='red')
+
+        if route_node_state.node_run_result:
+            node_run_result = route_node_state.node_run_result
+            self.print_text(f"Error: {node_run_result.error}", color='red')
+            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
+                            color='red')
+            self.print_text(
+                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
+                color='red')
+            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                            color='red')

-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
+    def on_node_text_chunk(
+        self,
+        event: NodeRunStreamChunkEvent
+    ) -> None:
        """
        Publish text chunk
        """
-        if not self.current_node_id or self.current_node_id != node_id:
-            self.current_node_id = node_id
-            self.print_text('\n[on_node_text_chunk]')
-            self.print_text(f"Node ID: {node_id}")
-            self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
+        route_node_state = event.route_node_state
+        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
+            self.current_node_id = route_node_state.node_id
+            self.print_text('\n[NodeRunStreamChunkEvent]')
+            self.print_text(f"Node ID: {route_node_state.node_id}")
+
+            node_run_result = route_node_state.node_run_result
+            if node_run_result:
+                self.print_text(
+                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")

-        self.print_text(text, color="pink", end="")
+        self.print_text(event.chunk_content, color="pink", end="")

-    def on_workflow_iteration_started(self,
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: dict = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
+    def on_workflow_parallel_started(
+        self,
+        event: ParallelBranchRunStartedEvent
+    ) -> None:
+        """
+        Publish parallel started
+        """
+        self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
+        self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
+        if event.in_iteration_id:
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
+
+    def on_workflow_parallel_completed(
+        self,
+        event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
+    ) -> None:
+        """
+        Publish parallel completed
+        """
+        if isinstance(event, ParallelBranchRunSucceededEvent):
+            color = 'blue'
+        elif isinstance(event, ParallelBranchRunFailedEvent):
+            color = 'red'
+
+        self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
+        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
+        if event.in_iteration_id:
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
+
+        if isinstance(event, ParallelBranchRunFailedEvent):
+            self.print_text(f"Error: {event.error}", color=color)

+    def on_workflow_iteration_started(
+        self,
+        event: IterationRunStartedEvent
+    ) -> None:
        """
        Publish iteration started
        """
-        self.print_text("\n[on_workflow_iteration_started]", color='blue')
-        self.print_text(f"Node ID: {node_id}", color='blue')
+        self.print_text("\n[IterationRunStartedEvent]", color='blue')
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')

-    def on_workflow_iteration_next(self, node_id: str,
-                                   node_type: NodeType,
-                                   index: int,
-                                   node_run_index: int,
-                                   output: Optional[dict]) -> None:
+    def on_workflow_iteration_next(
+        self,
+        event: IterationRunNextEvent
+    ) -> None:
        """
        Publish iteration next
        """
-        self.print_text("\n[on_workflow_iteration_next]", color='blue')
+        self.print_text("\n[IterationRunNextEvent]", color='blue')
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
+        self.print_text(f"Iteration Index: {event.index}", color='blue')

-    def on_workflow_iteration_completed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
+    def on_workflow_iteration_completed(
+        self,
+        event: IterationRunSucceededEvent | IterationRunFailedEvent
+    ) -> None:
        """
        Publish iteration completed
        """
-        self.print_text("\n[on_workflow_iteration_completed]", color='blue')
-
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
-        """
-        self.print_text("\n[on_workflow_event]", color='blue')
-        self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
+        self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
+        self.print_text(f"Node ID: {event.iteration_id}", color='blue')

    def print_text(
        self, text: str, color: Optional[str] = None, end: str = "\n"
@ -1,3 +1,4 @@
+from datetime import datetime
from enum import Enum
from typing import Any, Optional


@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState


class QueueEvent(str, Enum):

@ -31,6 +33,9 @@ class QueueEvent(str, Enum):
    ANNOTATION_REPLY = "annotation_reply"
    AGENT_THOUGHT = "agent_thought"
    MESSAGE_FILE = "message_file"
+    PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
+    PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
+    PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
    ERROR = "error"
    PING = "ping"
    STOP = "stop"

@ -38,7 +43,7 @@ class QueueEvent(str, Enum):

class AppQueueEvent(BaseModel):
    """
-    QueueEvent entity
+    QueueEvent abstract entity
    """
    event: QueueEvent


@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent):
    """
    QueueLLMChunkEvent entity
+    Only for basic mode apps
    """
    event: QueueEvent = QueueEvent.LLM_CHUNK
    chunk: LLMResultChunk

@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent):
    QueueIterationStartEvent entity
    """
    event: QueueEvent = QueueEvent.ITERATION_START
+    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    start_at: datetime

    node_run_index: int
-    inputs: dict = None
+    inputs: Optional[dict[str, Any]] = None
    predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict] = None
+    metadata: Optional[dict[str, Any]] = None


class QueueIterationNextEvent(AppQueueEvent):
    """

@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent):
    event: QueueEvent = QueueEvent.ITERATION_NEXT

    index: int
+    node_execution_id: str
    node_id: str
    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""

    node_run_index: int
    output: Optional[Any] = None  # output for the current iteration
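
Iteration events now carry their own execution identity plus the parallel lineage of the hosting node. A hedged construction example (values are illustrative; node_data stands in for a real BaseNodeData subclass instance):

    event = QueueIterationNextEvent(
        node_execution_id='exec-123',
        node_id='iteration-node',
        node_type=NodeType.ITERATION,
        node_data=node_data,
        index=0,
        node_run_index=7,
        output={'text': 'result of the first item'},
    )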

@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent):
    """
    QueueIterationCompletedEvent entity
    """
-    event:QueueEvent = QueueEvent.ITERATION_COMPLETED
+    event: QueueEvent = QueueEvent.ITERATION_COMPLETED

+    node_execution_id: str
    node_id: str
    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    start_at: datetime

    node_run_index: int
-    outputs: dict
+    inputs: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
+    steps: int = 0
+
+    error: Optional[str] = None


class QueueTextChunkEvent(AppQueueEvent):
    """

@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent):
    """
    event: QueueEvent = QueueEvent.TEXT_CHUNK
    text: str
-    metadata: Optional[dict] = None
+    from_variable_selector: Optional[list[str]] = None
+    """from variable selector"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""


class QueueAgentMessageEvent(AppQueueEvent):

@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
    """
    event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
    retriever_resources: list[dict]
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""


class QueueAnnotationReplyEvent(AppQueueEvent):

@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
    QueueWorkflowStartedEvent entity
    """
    event: QueueEvent = QueueEvent.WORKFLOW_STARTED
+    graph_runtime_state: GraphRuntimeState


class QueueWorkflowSucceededEvent(AppQueueEvent):

@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
    QueueWorkflowSucceededEvent entity
    """
    event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
+    outputs: Optional[dict[str, Any]] = None


class QueueWorkflowFailedEvent(AppQueueEvent):

@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
    """
    event: QueueEvent = QueueEvent.NODE_STARTED

+    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    node_run_index: int = 1
    predecessor_node_id: Optional[str] = None
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime


class QueueNodeSucceededEvent(AppQueueEvent):

@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
    """
    event: QueueEvent = QueueEvent.NODE_SUCCEEDED

+    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime

-    inputs: Optional[dict] = None
-    process_data: Optional[dict] = None
-    outputs: Optional[dict] = None
-    execution_metadata: Optional[dict] = None
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    error: Optional[str] = None


@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent):
    """
    event: QueueEvent = QueueEvent.NODE_FAILED

+    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime

-    inputs: Optional[dict] = None
-    outputs: Optional[dict] = None
-    process_data: Optional[dict] = None
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None

    error: str


@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent):
    event: QueueEvent = QueueEvent.STOP
    stopped_by: StopBy

+    def get_stop_reason(self) -> str:
+        """
+        To stop reason
+        """
+        reason_mapping = {
+            QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
+            QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
+            QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
+            QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
+        }
+
+        return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
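
Usage is straightforward; the mapping falls back to a generic reason for unknown members:

    stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
    stop_event.get_stop_reason()  # 'Stopped by user.'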

class QueueMessage(BaseModel):
    """
-    QueueMessage entity
+    QueueMessage abstract entity
    """
    task_id: str
    app_mode: str

@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage):
    WorkflowQueueMessage entity
    """
    pass


+class QueueParallelBranchRunStartedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunStartedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunSucceededEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class QueueParallelBranchRunFailedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunFailedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    error: str
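
The three classes above model a branch's lifecycle; only the failed variant carries an error. A hedged sketch of the start/fail pair (ids invented for illustration):

    started = QueueParallelBranchRunStartedEvent(
        parallel_id='parallel-1',
        parallel_start_node_id='branch-a-node-1',
    )
    failed = QueueParallelBranchRunFailedEvent(
        parallel_id='parallel-1',
        parallel_start_node_id='branch-a-node-1',
        error='tool call timed out',
    )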

@ -3,40 +3,11 @@ from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus


-class WorkflowStreamGenerateNodes(BaseModel):
-    """
-    WorkflowStreamGenerateNodes entity
-    """
-    end_node_id: str
-    stream_node_ids: list[str]
-
-
-class ChatflowStreamGenerateRoute(BaseModel):
-    """
-    ChatflowStreamGenerateRoute entity
-    """
-    answer_node_id: str
-    generate_route: list[GenerateRouteChunk]
-    current_route_position: int = 0
-
-
-class NodeExecutionInfo(BaseModel):
-    """
-    NodeExecutionInfo entity
-    """
-    workflow_node_execution_id: str
-    node_type: NodeType
-    start_at: float
-
-
class TaskState(BaseModel):
    """
    TaskState entity

@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState):
    """
    answer: str = ""

-    workflow_run_id: Optional[str] = None
-    start_at: Optional[float] = None
-    total_tokens: int = 0
-    total_steps: int = 0
-
-    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
-    latest_node_execution_info: Optional[NodeExecutionInfo] = None
-
-    current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
-
-    iteration_nested_node_ids: list[str] = None
-
-
-class AdvancedChatTaskState(WorkflowTaskState):
-    """
-    AdvancedChatTaskState entity
-    """
-    usage: LLMUsage
-
-    current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
-

class StreamEvent(Enum):
    """

@ -97,6 +47,8 @@ class StreamEvent(Enum):
    WORKFLOW_FINISHED = "workflow_finished"
    NODE_STARTED = "node_started"
    NODE_FINISHED = "node_finished"
+    PARALLEL_BRANCH_STARTED = "parallel_branch_started"
+    PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
    ITERATION_STARTED = "iteration_started"
    ITERATION_NEXT = "iteration_next"
    ITERATION_COMPLETED = "iteration_completed"

@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse):
        inputs: Optional[dict] = None
        created_at: int
        extras: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_STARTED
    workflow_run_id: str

@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse):
                "predecessor_node_id": self.data.predecessor_node_id,
                "inputs": None,
                "created_at": self.data.created_at,
-                "extras": {}
+                "extras": {},
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
            }
        }
@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse):
        created_at: int
        finished_at: int
        files: Optional[list[dict]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_FINISHED
    workflow_run_id: str

@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse):
                "execution_metadata": None,
                "created_at": self.data.created_at,
                "finished_at": self.data.finished_at,
-                "files": []
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
            }
        }


+class ParallelBranchStartStreamResponse(StreamResponse):
+    """
+    ParallelBranchStartStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+        parallel_id: str
+        parallel_branch_id: str
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        created_at: int
+
+    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
+    workflow_run_id: str
+    data: Data
+
+
+class ParallelBranchFinishedStreamResponse(StreamResponse):
+    """
+    ParallelBranchFinishedStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+        parallel_id: str
+        parallel_branch_id: str
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        status: str
+        error: Optional[str] = None
+        created_at: int
+
+    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
+    workflow_run_id: str
+    data: Data
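
These mirror the queue events onto the stream API. Roughly what gets built for a branch start (field values illustrative; serialization follows the usual StreamResponse fields):

    resp = ParallelBranchStartStreamResponse(
        task_id='task-1',
        workflow_run_id='run-1',
        data=ParallelBranchStartStreamResponse.Data(
            parallel_id='parallel-1',
            parallel_branch_id='branch-a-node-1',
            created_at=1725000000,
        ),
    )
    # serialized as event 'parallel_branch_started' with parallel_id / parallel_branch_id in data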

class IterationNodeStartStreamResponse(StreamResponse):

@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
        extras: dict = {}
        metadata: dict = {}
        inputs: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_STARTED
    workflow_run_id: str

@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
        created_at: int
        pre_iteration_output: Optional[Any] = None
        extras: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_NEXT
    workflow_run_id: str

@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
        title: str
        outputs: Optional[dict] = None
        created_at: int
-        extras: dict = None
-        inputs: dict = None
+        extras: Optional[dict] = None
+        inputs: Optional[dict] = None
        status: WorkflowNodeExecutionStatus
        error: Optional[str] = None
        elapsed_time: float

@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
        execution_metadata: Optional[dict] = None
        finished_at: int
        steps: int
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_COMPLETED
    workflow_run_id: str

@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
    """
    WorkflowAppStreamResponse entity
    """
-    workflow_run_id: str
+    workflow_run_id: Optional[str] = None


class AppBlockingResponse(BaseModel):

@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):

    workflow_run_id: str
    data: Data
-
-
-class WorkflowIterationState(BaseModel):
-    """
-    WorkflowIterationState entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-        parent_iteration_id: Optional[str] = None
-        iteration_id: str
-        current_index: int
-        iteration_steps_boundary: list[int] = None
-        node_execution_id: str
-        started_at: float
-        inputs: dict = None
-        total_tokens: int = 0
-        node_data: BaseNodeData
-
-    current_iterations: dict[str, Data] = None
@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline:
        err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))

        if message:
-            message = db.session.query(Message).filter(Message.id == message.id).first()
-            err_desc = self._error_to_desc(err)
-            message.status = 'error'
-            message.error = err_desc
+            refetch_message = db.session.query(Message).filter(Message.id == message.id).first()

-            db.session.commit()
+            if refetch_message:
+                err_desc = self._error_to_desc(err)
+                refetch_message.status = 'error'
+                refetch_message.error = err_desc
+
+                db.session.commit()

        return err
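
The refetch guard matters because the Message row can disappear between the original load and this error handler; .first() returns None instead of raising, so the update is skipped cleanly. The pattern in isolation:

    row = db.session.query(Message).filter(Message.id == message.id).first()
    if row:
        row.status = 'error'
        row.error = err_desc
        db.session.commit()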

-    def _error_to_desc(cls, e: Exception) -> str:
+    def _error_to_desc(self, e: Exception) -> str:
        """
        Error to desc.
        :param e: exception
@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ChatAppGenerateEntity,
    CompletionAppGenerateEntity,
-    InvokeFrom,
)
from core.app.entities.queue_entities import (
    QueueAnnotationReplyEvent,

@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
    QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
-    AdvancedChatTaskState,
    EasyUITaskState,
    MessageFileStreamResponse,
    MessageReplaceStreamResponse,
    MessageStreamResponse,
+    WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager

@ -36,7 +35,7 @@ class MessageCycleManage:
        AgentChatAppGenerateEntity,
        AdvancedChatAppGenerateEntity
    ]
-    _task_state: Union[EasyUITaskState, AdvancedChatTaskState]
+    _task_state: Union[EasyUITaskState, WorkflowTaskState]

    def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
        """

@ -45,6 +44,9 @@ class MessageCycleManage:
        :param query: query
        :return: thread
        """
+        if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
+            return None
+
        is_first_message = self._application_generate_entity.conversation_id is None
        extras = self._application_generate_entity.extras
        auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)

@ -52,7 +54,7 @@ class MessageCycleManage:
        if auto_generate_conversation_name and is_first_message:
            # start generate thread
            thread = Thread(target=self._generate_conversation_name_worker, kwargs={
-                'flask_app': current_app._get_current_object(),
+                'flask_app': current_app._get_current_object(),  # type: ignore
                'conversation_id': conversation.id,
                'query': query
            })

@ -75,6 +77,9 @@ class MessageCycleManage:
            .first()
        )

+        if not conversation:
+            return
+
        if conversation.mode != AppMode.COMPLETION.value:
            app_model = conversation.app
            if not app_model:

@ -121,34 +126,13 @@ class MessageCycleManage:
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata['retriever_resources'] = event.retriever_resources

-    def _get_response_metadata(self) -> dict:
-        """
-        Get response metadata by invoke from.
-        :return:
-        """
-        metadata = {}
-
-        # show_retrieve_source
-        if 'retriever_resources' in self._task_state.metadata:
-            metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
-
-        # show annotation reply
-        if 'annotation_reply' in self._task_state.metadata:
-            metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
-
-        # show usage
-        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            metadata['usage'] = self._task_state.metadata['usage']
-
-        return metadata
-
    def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
        """
        Message file to stream response.
        :param event: event
        :return:
        """
-        message_file: MessageFile = (
+        message_file = (
            db.session.query(MessageFile)
            .filter(MessageFile.id == event.message_file_id)
            .first()
@ -1,33 +1,41 @@
import json
import time
from datetime import datetime, timezone
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
    QueueStopEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
-    NodeExecutionInfo,
+    IterationNodeCompletedStreamResponse,
+    IterationNodeNextStreamResponse,
+    IterationNodeStartStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
+    ParallelBranchFinishedStreamResponse,
+    ParallelBranchStartStreamResponse,
    WorkflowFinishStreamResponse,
    WorkflowStartStreamResponse,
    WorkflowTaskState,
)
-from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.tool.entities import ToolNodeData
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser

@ -41,54 +49,56 @@ from models.workflow import (
    WorkflowRunStatus,
    WorkflowRunTriggeredFrom,
)
-from services.workflow_service import WorkflowService


-class WorkflowCycleManage(WorkflowIterationCycleManage):
-    def _init_workflow_run(self, workflow: Workflow,
-                           triggered_from: WorkflowRunTriggeredFrom,
-                           user: Union[Account, EndUser],
-                           user_inputs: dict,
-                           system_inputs: Optional[dict] = None) -> WorkflowRun:
-        """
-        Init workflow run
-        :param workflow: Workflow instance
-        :param triggered_from: triggered from
-        :param user: account or end user
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
-        :return:
-        """
-        max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
-            .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
-            .filter(WorkflowRun.app_id == workflow.app_id) \
-            .scalar() or 0
+class WorkflowCycleManage:
+    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
+    _workflow: Workflow
+    _user: Union[Account, EndUser]
+    _task_state: WorkflowTaskState
+    _workflow_system_variables: dict[SystemVariableKey, Any]
+
+    def _handle_workflow_run_start(self) -> WorkflowRun:
+        max_sequence = (
+            db.session.query(db.func.max(WorkflowRun.sequence_number))
+            .filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
+            .filter(WorkflowRun.app_id == self._workflow.app_id)
+            .scalar()
+            or 0
+        )
        new_sequence_number = max_sequence + 1

-        inputs = {**user_inputs}
-        for key, value in (system_inputs or {}).items():
+        inputs = {**self._application_generate_entity.inputs}
+        for key, value in (self._workflow_system_variables or {}).items():
+            if key.value == 'conversation':
+                continue
+
            inputs[f'sys.{key.value}'] = value
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
+
+        inputs = WorkflowEntry.handle_special_values(inputs)
+
+        triggered_from = (
+            WorkflowRunTriggeredFrom.DEBUGGING
+            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
+            else WorkflowRunTriggeredFrom.APP_RUN
+        )

        # init workflow run
-        workflow_run = WorkflowRun(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
-            sequence_number=new_sequence_number,
-            workflow_id=workflow.id,
-            type=workflow.type,
-            triggered_from=triggered_from.value,
-            version=workflow.version,
-            graph=workflow.graph,
-            inputs=json.dumps(inputs),
-            status=WorkflowRunStatus.RUNNING.value,
-            created_by_role=(CreatedByRole.ACCOUNT.value
-                             if isinstance(user, Account) else CreatedByRole.END_USER.value),
-            created_by=user.id
+        workflow_run = WorkflowRun()
+        workflow_run.tenant_id = self._workflow.tenant_id
+        workflow_run.app_id = self._workflow.app_id
+        workflow_run.sequence_number = new_sequence_number
+        workflow_run.workflow_id = self._workflow.id
+        workflow_run.type = self._workflow.type
+        workflow_run.triggered_from = triggered_from.value
+        workflow_run.version = self._workflow.version
+        workflow_run.graph = self._workflow.graph
+        workflow_run.inputs = json.dumps(inputs)
+        workflow_run.status = WorkflowRunStatus.RUNNING.value
+        workflow_run.created_by_role = (
+            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
+        )
+        workflow_run.created_by = self._user.id

        db.session.add(workflow_run)
        db.session.commit()

@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

        return workflow_run
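
The per-app run counter is a read-then-increment over MAX(sequence_number); the `or 0` covers the first run, where MAX() is NULL and .scalar() returns None. In isolation (note this is not guarded against two runs starting concurrently):

    max_sequence = (
        db.session.query(db.func.max(WorkflowRun.sequence_number))
        .filter(WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id)
        .scalar()
        or 0
    )
    new_sequence_number = max_sequence + 1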
|
||||
|
||||
def _workflow_run_success(
|
||||
self, workflow_run: WorkflowRun,
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
|
@@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 )
             )

-        db.session.close()
-
         return workflow_run

-    def _workflow_run_failed(
-        self, workflow_run: WorkflowRun,
+    def _handle_workflow_run_failed(
+        self,
+        workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
         total_steps: int,
         status: WorkflowRunStatus,
         error: str,
         conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
         """
         Workflow run failed
         :param workflow_run: workflow run
         :param start_at: start time
         :param total_tokens: total tokens
         :param total_steps: total steps
         :param status: status
         :param error: error message
         :return:
         """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)
+
         workflow_run.status = status.value
         workflow_run.error = error
-        workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
+        workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

         db.session.commit()

+        running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+            WorkflowNodeExecution.app_id == workflow_run.app_id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
+        ).all()
+
+        for workflow_node_execution in running_workflow_node_executions:
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
+            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
+            db.session.commit()
+
+        db.session.refresh(workflow_run)
+        db.session.close()
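Worth noting in the hunk above: the run itself is timed with time.perf_counter() (a monotonic clock, immune to wall-clock adjustments), while the swept node executions derive elapsed time from their persisted created_at/finished_at timestamps, which also works for rows closed after the fact. A small sketch contrasting the two conventions:

    import time
    from datetime import datetime, timezone

    start_at = time.perf_counter()
    # ... run the workflow ...
    run_elapsed = time.perf_counter() - start_at  # monotonic duration

    # Timestamp-derived duration, recomputable later from stored columns:
    created_at = datetime.now(timezone.utc).replace(tzinfo=None)
    # ... node runs ...
    finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
    node_elapsed = (finished_at - created_at).total_seconds()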
@@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_run

-    def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
-                                               node_id: str,
-                                               node_type: NodeType,
-                                               node_title: str,
-                                               node_run_index: int = 1,
-                                               predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
-        """
-        Init workflow node execution from workflow run
-        :param workflow_run: workflow run
-        :param node_id: node id
-        :param node_type: node type
-        :param node_title: node title
-        :param node_run_index: run index
-        :param predecessor_node_id: predecessor node id if exists
-        :return:
-        """
+    def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
         # init workflow node execution
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            workflow_run_id=workflow_run.id,
-            predecessor_node_id=predecessor_node_id,
-            index=node_run_index,
-            node_id=node_id,
-            node_type=node_type.value,
-            title=node_title,
-            status=WorkflowNodeExecutionStatus.RUNNING.value,
-            created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by,
-            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
-        )
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.index = event.node_run_index
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)

         db.session.add(workflow_node_execution)
         db.session.commit()
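With parallel branches the same node_id can be started more than once within a single run, so the new handler persists the event's node_execution_id and all later events look rows up by that key rather than by node id. A minimal sketch of the pattern, with create_execution_row/close_execution_row as hypothetical stand-ins for the persistence shown above:

    executions: dict[str, WorkflowNodeExecution] = {}

    def on_node_started(event) -> None:
        # One row per execution, not per node: parallel runs of the same
        # node carry distinct node_execution_id values.
        executions[event.node_execution_id] = create_execution_row(event)

    def on_node_finished(event) -> None:
        execution = executions[event.node_execution_id]  # not event.node_id
        close_execution_row(execution, event)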
@@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
-                                         start_at: float,
-                                         inputs: Optional[dict] = None,
-                                         process_data: Optional[dict] = None,
-                                         outputs: Optional[dict] = None,
-                                         execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution success
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param inputs: inputs
-        :param process_data: process data
-        :param outputs: outputs
-        :param execution_metadata: execution metadata
+        :param event: queue node succeeded event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)

         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.execution_metadata = (
+            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+        )
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

         db.session.commit()
         db.session.refresh(workflow_node_execution)
@@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
-                                        start_at: float,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        execution_metadata: Optional[dict] = None
-                                        ) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param error: error message
+        :param event: queue node failed event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)

         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-        workflow_node_execution.error = error
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.error = event.error
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

         db.session.commit()
         db.session.refresh(workflow_node_execution)
@@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _workflow_start_to_stream_response(self, task_id: str,
-                                           workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
+    #################################################
+    #             to stream responses               #
+    #################################################
+
+    def _workflow_start_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun
+    ) -> WorkflowStartStreamResponse:
         """
         Workflow start to stream response.
         :param task_id: task id
@@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 id=workflow_run.id,
                 workflow_id=workflow_run.workflow_id,
                 sequence_number=workflow_run.sequence_number,
-                inputs=workflow_run.inputs_dict,
-                created_at=int(workflow_run.created_at.timestamp())
-            )
-        )
+                inputs=workflow_run.inputs_dict or {},
+                created_at=int(workflow_run.created_at.timestamp()),
+            ),
+        )

-    def _workflow_finish_to_stream_response(self, task_id: str,
-                                            workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
+    def _workflow_finish_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun
+    ) -> WorkflowFinishStreamResponse:
         """
         Workflow finish to stream response.
         :param task_id: task id
@@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         created_by_account = workflow_run.created_by_account
         if created_by_account:
             created_by = {
-                "id": created_by_account.id,
-                "name": created_by_account.name,
-                "email": created_by_account.email,
+                'id': created_by_account.id,
+                'name': created_by_account.name,
+                'email': created_by_account.email,
             }
         else:
             created_by_end_user = workflow_run.created_by_end_user
             if created_by_end_user:
                 created_by = {
-                    "id": created_by_end_user.id,
-                    "user": created_by_end_user.session_id,
+                    'id': created_by_end_user.id,
+                    'user': created_by_end_user.session_id,
                 }

         return WorkflowFinishStreamResponse(
@@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 created_by=created_by,
                 created_at=int(workflow_run.created_at.timestamp()),
                 finished_at=int(workflow_run.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
-            )
-        )
+                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
+            ),
+        )

-    def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
-                                                task_id: str,
-                                                workflow_node_execution: WorkflowNodeExecution) \
-            -> NodeStartStreamResponse:
+    def _workflow_node_start_to_stream_response(
+        self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
+    ) -> Optional[NodeStartStreamResponse]:
         """
         Workflow node start to stream response.
         :param event: queue node started event
@@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         :param workflow_node_execution: workflow node execution
         :return:
         """
+        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+            return None
+
         response = NodeStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 index=workflow_node_execution.index,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
                 inputs=workflow_node_execution.inputs_dict,
-                created_at=int(workflow_node_execution.created_at.timestamp())
-            )
-        )
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+            ),
+        )

         # extras logic
@@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             response.data.extras['icon'] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
                 provider_type=node_data.provider_type,
-                provider_id=node_data.provider_id
+                provider_id=node_data.provider_id,
             )

         return response

-    def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-            -> NodeFinishStreamResponse:
+    def _workflow_node_finish_to_stream_response(
+        self,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution
+    ) -> Optional[NodeFinishStreamResponse]:
         """
         Workflow node finish to stream response.
+        :param event: queue node succeeded or failed event
         :param task_id: task id
         :param workflow_node_execution: workflow node execution
         :return:
         """
+        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+            return None
+
         return NodeFinishStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 execution_metadata=workflow_node_execution.execution_metadata_dict,
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 finished_at=int(workflow_node_execution.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
             ),
         )

+    def _workflow_parallel_branch_start_to_stream_response(
+        self,
+        task_id: str,
+        workflow_run: WorkflowRun,
+        event: QueueParallelBranchRunStartedEvent
+    ) -> ParallelBranchStartStreamResponse:
+        """
+        Workflow parallel branch start to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: parallel branch run started event
+        :return:
+        """
+        return ParallelBranchStartStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=ParallelBranchStartStreamResponse.Data(
+                parallel_id=event.parallel_id,
+                parallel_branch_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                created_at=int(time.time()),
+            )
+        )

+    def _workflow_parallel_branch_finished_to_stream_response(
+        self,
+        task_id: str,
+        workflow_run: WorkflowRun,
+        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
+    ) -> ParallelBranchFinishedStreamResponse:
+        """
+        Workflow parallel branch finished to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: parallel branch run succeeded or failed event
+        :return:
+        """
+        return ParallelBranchFinishedStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=ParallelBranchFinishedStreamResponse.Data(
+                parallel_id=event.parallel_id,
+                parallel_branch_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
+                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
+                created_at=int(time.time()),
+            )
+        )
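Each parallel-branch frame above carries parallel_id, the branch's start node id, and the parent parallel ids, which is enough for a client to rebuild the nesting. A hypothetical consumer that groups frames per branch (field names taken from the Data payloads above; the consumer itself is not part of this commit):

    from collections import defaultdict

    branches: dict[tuple[str, str], list[dict]] = defaultdict(list)

    def on_stream_event(frame: dict) -> None:
        # Group frames by (parallel_id, branch id) so nested branches can
        # be rendered as a tree instead of an interleaved flat list.
        data = frame.get('data', {})
        parallel_id = data.get('parallel_id')
        branch_id = data.get('parallel_branch_id') or data.get('parallel_start_node_id')
        if parallel_id and branch_id:
            branches[(parallel_id, branch_id)].append(frame)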
-    def _handle_workflow_start(self) -> WorkflowRun:
-        self._task_state.start_at = time.perf_counter()
-
-        workflow_run = self._init_workflow_run(
-            workflow=self._workflow,
-            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
-            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
-            else WorkflowRunTriggeredFrom.APP_RUN,
-            user=self._user,
-            user_inputs=self._application_generate_entity.inputs,
-            system_inputs=self._workflow_system_variables
-        )
-
-        self._task_state.workflow_run_id = workflow_run.id
-
-        db.session.close()
-
-        return workflow_run
+    def _workflow_iteration_start_to_stream_response(
+        self,
+        task_id: str,
+        workflow_run: WorkflowRun,
+        event: QueueIterationStartEvent
+    ) -> IterationNodeStartStreamResponse:
+        """
+        Workflow iteration start to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration start event
+        :return:
+        """
+        return IterationNodeStartStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeStartStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
+                created_at=int(time.time()),
+                extras={},
+                inputs=event.inputs or {},
+                metadata=event.metadata or {},
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+            )
+        )

-    def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
-        workflow_node_execution = self._init_node_execution_from_workflow_run(
-            workflow_run=workflow_run,
-            node_id=event.node_id,
-            node_type=event.node_type,
-            node_title=event.node_data.title,
-            node_run_index=event.node_run_index,
-            predecessor_node_id=event.predecessor_node_id
-        )
-
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=event.node_type,
-            start_at=time.perf_counter()
-        )
-
-        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-
-        self._task_state.total_steps += 1
-
-        db.session.close()
-
-        return workflow_node_execution
+    def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
+        """
+        Workflow iteration next to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration next event
+        :return:
+        """
+        return IterationNodeNextStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeNextStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
+                index=event.index,
+                pre_iteration_output=event.output,
+                created_at=int(time.time()),
+                extras={},
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+            )
+        )

-    def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
-        current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
-        workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
-
-        execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
-
-        if self._iteration_state and self._iteration_state.current_iterations:
-            if not execution_metadata:
-                execution_metadata = {}
-            current_iteration_data = None
-            for iteration_node_id in self._iteration_state.current_iterations:
-                data = self._iteration_state.current_iterations[iteration_node_id]
-                if data.parent_iteration_id == None:
-                    current_iteration_data = data
-                    break
-
-            if current_iteration_data:
-                execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
-                execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
-
-        if isinstance(event, QueueNodeSucceededEvent):
-            workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=workflow_node_execution,
-                start_at=current_node_execution.start_at,
-                inputs=event.inputs,
-                process_data=event.process_data,
-                outputs=event.outputs,
-                execution_metadata=execution_metadata
-            )
-
-            if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                self._task_state.total_tokens += (
-                    int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
-
-            if self._iteration_state:
-                for iteration_node_id in self._iteration_state.current_iterations:
-                    data = self._iteration_state.current_iterations[iteration_node_id]
-                    if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                        data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
-
-            if workflow_node_execution.node_type == NodeType.LLM.value:
-                outputs = workflow_node_execution.outputs_dict
-                usage_dict = outputs.get('usage', {})
-                self._task_state.metadata['usage'] = usage_dict
-        else:
-            workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=workflow_node_execution,
-                start_at=current_node_execution.start_at,
-                error=event.error,
-                inputs=event.inputs,
-                process_data=event.process_data,
-                outputs=event.outputs,
-                execution_metadata=execution_metadata
-            )
-
-        db.session.close()
-
-        return workflow_node_execution
+    def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
+        """
+        Workflow iteration completed to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration completed event
+        :return:
+        """
+        return IterationNodeCompletedStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeCompletedStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
+                outputs=event.outputs,
+                created_at=int(time.time()),
+                extras={},
+                inputs=event.inputs or {},
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                error=None,
+                elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
+                total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
+                execution_metadata=event.metadata,
+                finished_at=int(time.time()),
+                steps=event.steps,
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+            )
+        )

-    def _handle_workflow_finished(
-        self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
-        conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
-    ) -> Optional[WorkflowRun]:
-        workflow_run = db.session.query(WorkflowRun).filter(
-            WorkflowRun.id == self._task_state.workflow_run_id).first()
-        if not workflow_run:
-            return None
-
-        if conversation_id is None:
-            conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
-        if isinstance(event, QueueStopEvent):
-            workflow_run = self._workflow_run_failed(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                status=WorkflowRunStatus.STOPPED,
-                error='Workflow stopped.',
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-
-            latest_node_execution_info = self._task_state.latest_node_execution_info
-            if latest_node_execution_info:
-                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                    WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
-                if (workflow_node_execution
-                        and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
-                    self._workflow_node_execution_failed(
-                        workflow_node_execution=workflow_node_execution,
-                        start_at=latest_node_execution_info.start_at,
-                        error='Workflow stopped.'
-                    )
-        elif isinstance(event, QueueWorkflowFailedEvent):
-            workflow_run = self._workflow_run_failed(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                status=WorkflowRunStatus.FAILED,
-                error=event.error,
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-        else:
-            if self._task_state.latest_node_execution_info:
-                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                    WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
-                outputs = workflow_node_execution.outputs
-            else:
-                outputs = None
-
-            workflow_run = self._workflow_run_success(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                outputs=outputs,
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-
-        self._task_state.workflow_run_id = workflow_run.id
-
-        db.session.close()
-
-        return workflow_run

     def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
         """
@@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             return value.to_dict()

         return None
+
+    def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+        """
+        Refetch workflow run
+        :param workflow_run_id: workflow run id
+        :return:
+        """
+        workflow_run = db.session.query(WorkflowRun).filter(
+            WorkflowRun.id == workflow_run_id).first()
+
+        if not workflow_run:
+            raise Exception(f'Workflow run not found: {workflow_run_id}')
+
+        return workflow_run
+
+    def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
+        """
+        Refetch workflow node execution
+        :param node_execution_id: workflow node execution id
+        :return:
+        """
+        workflow_node_execution = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
+                WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
+                WorkflowNodeExecution.workflow_id == self._workflow.id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+                WorkflowNodeExecution.node_execution_id == node_execution_id,
+            )
+            .first()
+        )
+
+        if not workflow_node_execution:
+            raise Exception(f'Workflow node execution not found: {node_execution_id}')
+
+        return workflow_node_execution
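The two refetch helpers exist because the queue handlers may run against a different session (or thread) than the code that wrote the rows, so everything is re-read by id and a missing row fails fast. The same idea in generic form — an illustrative helper, not something this commit adds:

    from typing import TypeVar

    T = TypeVar('T')

    def refetch_or_raise(session, model: type[T], **filters) -> T:
        # Re-read a row written earlier in the pipeline; surface a missing
        # id as an explicit error rather than a downstream AttributeError.
        row = session.query(model).filter_by(**filters).first()
        if row is None:
            raise ValueError(f'{model.__name__} not found: {filters}')
        return row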
@@ -1,16 +0,0 @@
-from typing import Any, Union
-
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
-from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
-from core.workflow.enums import SystemVariableKey
-from models.account import Account
-from models.model import EndUser
-from models.workflow import Workflow
-
-
-class WorkflowCycleStateManager:
-    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
-    _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
-    _workflow_system_variables: dict[SystemVariableKey, Any]
@@ -1,290 +0,0 @@
-import json
-import time
-from collections.abc import Generator
-from datetime import datetime, timezone
-from typing import Optional, Union
-
-from core.app.entities.queue_entities import (
-    QueueIterationCompletedEvent,
-    QueueIterationNextEvent,
-    QueueIterationStartEvent,
-)
-from core.app.entities.task_entities import (
-    IterationNodeCompletedStreamResponse,
-    IterationNodeNextStreamResponse,
-    IterationNodeStartStreamResponse,
-    NodeExecutionInfo,
-    WorkflowIterationState,
-)
-from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
-from extensions.ext_database import db
-from models.workflow import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionStatus,
-    WorkflowNodeExecutionTriggeredFrom,
-    WorkflowRun,
-)
-
-
-class WorkflowIterationCycleManage(WorkflowCycleStateManager):
-    _iteration_state: WorkflowIterationState = None
-
-    def _init_iteration_state(self) -> WorkflowIterationState:
-        if not self._iteration_state:
-            self._iteration_state = WorkflowIterationState(
-                current_iterations={}
-            )
-
-    def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-        -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
-        """
-        Handle iteration to stream response
-        :param task_id: task id
-        :param event: iteration event
-        :return:
-        """
-        if isinstance(event, QueueIterationStartEvent):
-            return IterationNodeStartStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeStartStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=event.node_data.title,
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=event.inputs,
-                    metadata=event.metadata
-                )
-            )
-        elif isinstance(event, QueueIterationNextEvent):
-            current_iteration = self._iteration_state.current_iterations[event.node_id]
-
-            return IterationNodeNextStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeNextStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=current_iteration.node_data.title,
-                    index=event.index,
-                    pre_iteration_output=event.output,
-                    created_at=int(time.time()),
-                    extras={}
-                )
-            )
-        elif isinstance(event, QueueIterationCompletedEvent):
-            current_iteration = self._iteration_state.current_iterations[event.node_id]
-
-            return IterationNodeCompletedStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeCompletedStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=current_iteration.node_data.title,
-                    outputs=event.outputs,
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=current_iteration.inputs,
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    error=None,
-                    elapsed_time=time.perf_counter() - current_iteration.started_at,
-                    total_tokens=current_iteration.total_tokens,
-                    execution_metadata={
-                        'total_tokens': current_iteration.total_tokens,
-                    },
-                    finished_at=int(time.time()),
-                    steps=current_iteration.current_index
-                )
-            )
-
-    def _init_iteration_execution_from_workflow_run(self,
-                                                    workflow_run: WorkflowRun,
-                                                    node_id: str,
-                                                    node_type: NodeType,
-                                                    node_title: str,
-                                                    node_run_index: int = 1,
-                                                    inputs: Optional[dict] = None,
-                                                    predecessor_node_id: Optional[str] = None
-                                                    ) -> WorkflowNodeExecution:
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            workflow_run_id=workflow_run.id,
-            predecessor_node_id=predecessor_node_id,
-            index=node_run_index,
-            node_id=node_id,
-            node_type=node_type.value,
-            inputs=json.dumps(inputs) if inputs else None,
-            title=node_title,
-            status=WorkflowNodeExecutionStatus.RUNNING.value,
-            created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by,
-            execution_metadata=json.dumps({
-                'started_run_index': node_run_index + 1,
-                'current_index': 0,
-                'steps_boundary': [],
-            }),
-            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
-        )
-
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
-
-        return workflow_node_execution
-
-    def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
-        if isinstance(event, QueueIterationStartEvent):
-            return self._handle_iteration_started(event)
-        elif isinstance(event, QueueIterationNextEvent):
-            return self._handle_iteration_next(event)
-        elif isinstance(event, QueueIterationCompletedEvent):
-            return self._handle_iteration_completed(event)
-
-    def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
-        self._init_iteration_state()
-
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
-        workflow_node_execution = self._init_iteration_execution_from_workflow_run(
-            workflow_run=workflow_run,
-            node_id=event.node_id,
-            node_type=NodeType.ITERATION,
-            node_title=event.node_data.title,
-            node_run_index=event.node_run_index,
-            inputs=event.inputs,
-            predecessor_node_id=event.predecessor_node_id
-        )
-
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=NodeType.ITERATION,
-            start_at=time.perf_counter()
-        )
-
-        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-
-        self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
-            parent_iteration_id=None,
-            iteration_id=event.node_id,
-            current_index=0,
-            iteration_steps_boundary=[],
-            node_execution_id=workflow_node_execution.id,
-            started_at=time.perf_counter(),
-            inputs=event.inputs,
-            total_tokens=0,
-            node_data=event.node_data
-        )
-
-        db.session.close()
-
-        return workflow_node_execution
-
-    def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
-        if event.node_id not in self._iteration_state.current_iterations:
-            return
-        current_iteration = self._iteration_state.current_iterations[event.node_id]
-        current_iteration.current_index = event.index
-        current_iteration.iteration_steps_boundary.append(event.node_run_index)
-        workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_iteration.node_execution_id
-        ).first()
-
-        original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
-        if original_node_execution_metadata:
-            original_node_execution_metadata['current_index'] = event.index
-            original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
-            original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
-            workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
-
-            db.session.commit()
-
-        db.session.close()
-
-    def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
-        if event.node_id not in self._iteration_state.current_iterations:
-            return
-
-        current_iteration = self._iteration_state.current_iterations[event.node_id]
-        workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_iteration.node_execution_id
-        ).first()
-
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
-        workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
-
-        original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
-        if original_node_execution_metadata:
-            original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
-            original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
-            workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
-
-        db.session.commit()
-
-        # remove current iteration
-        self._iteration_state.current_iterations.pop(event.node_id, None)
-
-        # set latest node execution info
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=NodeType.ITERATION,
-            start_at=time.perf_counter()
-        )
-
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-
-        db.session.close()
-
-    def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
-        """
-        Handle iteration exception
-        """
-        if not self._iteration_state or not self._iteration_state.current_iterations:
-            return
-
-        for node_id, current_iteration in self._iteration_state.current_iterations.items():
-            workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-                WorkflowNodeExecution.id == current_iteration.node_execution_id
-            ).first()
-
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-            workflow_node_execution.error = error
-            workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
-
-            db.session.commit()
-            db.session.close()
-
-            yield IterationNodeCompletedStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeCompletedStreamResponse.Data(
-                    id=node_id,
-                    node_id=node_id,
-                    node_type=NodeType.ITERATION.value,
-                    title=current_iteration.node_data.title,
-                    outputs={},
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=current_iteration.inputs,
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    error=error,
-                    elapsed_time=time.perf_counter() - current_iteration.started_at,
-                    total_tokens=current_iteration.total_tokens,
-                    execution_metadata={
-                        'total_tokens': current_iteration.total_tokens,
-                    },
-                    finished_at=int(time.time()),
-                    steps=current_iteration.current_index
-                )
-            )
@@ -63,6 +63,39 @@ class LLMUsage(ModelUsage):
             latency=0.0
         )

+    def plus(self, other: 'LLMUsage') -> 'LLMUsage':
+        """
+        Add two LLMUsage instances together.
+
+        :param other: Another LLMUsage instance to add
+        :return: A new LLMUsage instance with summed values
+        """
+        if self.total_tokens == 0:
+            return other
+        else:
+            return LLMUsage(
+                prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+                prompt_unit_price=other.prompt_unit_price,
+                prompt_price_unit=other.prompt_price_unit,
+                prompt_price=self.prompt_price + other.prompt_price,
+                completion_tokens=self.completion_tokens + other.completion_tokens,
+                completion_unit_price=other.completion_unit_price,
+                completion_price_unit=other.completion_price_unit,
+                completion_price=self.completion_price + other.completion_price,
+                total_tokens=self.total_tokens + other.total_tokens,
+                total_price=self.total_price + other.total_price,
+                currency=other.currency,
+                latency=self.latency + other.latency
+            )
+
+    def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
+        """
+        Overload the + operator to add two LLMUsage instances.
+
+        :param other: Another LLMUsage instance to add
+        :return: A new LLMUsage instance with summed values
+        """
+        return self.plus(other)
+

 class LLMResult(BaseModel):
     """
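The addition above is deliberately asymmetric: token counts, prices, and latency are summed, but unit prices and currency are taken from the right-hand operand, and an empty left-hand usage short-circuits to other, so folding usages in event order keeps the most recent pricing. A usage sketch (the branch usages are placeholders, and empty_usage() is the constructor whose tail is visible at the top of this hunk):

    total = LLMUsage.empty_usage()
    for branch_usage in (usage_a, usage_b, usage_c):  # e.g. one per parallel branch
        total = total + branch_usage  # __add__ delegates to plus()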
@@ -34,13 +34,13 @@ class OutputModeration(BaseModel):
     final_output: Optional[str] = None
     model_config = ConfigDict(arbitrary_types_allowed=True)

-    def should_direct_output(self):
+    def should_direct_output(self) -> bool:
         return self.final_output is not None

-    def get_final_output(self):
-        return self.final_output
+    def get_final_output(self) -> str:
+        return self.final_output or ""

-    def append_new_token(self, token: str):
+    def append_new_token(self, token: str) -> None:
         self.buffer += token

         if not self.thread:
@@ -1,7 +1,7 @@
 import json
 import logging
 from copy import deepcopy
-from typing import Any, Union
+from typing import Any, Optional, Union

 from core.file.file_obj import FileTransferMethod, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType

@@ -18,6 +18,7 @@ class WorkflowTool(Tool):
     version: str
     workflow_entities: dict[str, Any]
     workflow_call_depth: int
+    thread_pool_id: Optional[str] = None

     label: str

@@ -57,6 +58,7 @@ class WorkflowTool(Tool):
             invoke_from=self.runtime.invoke_from,
             stream=False,
             call_depth=self.workflow_call_depth + 1,
+            workflow_thread_pool_id=self.thread_pool_id
         )

         data = result.get('data', {})
@@ -128,6 +128,7 @@ class ToolEngine:
                         user_id: str,
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
+                        thread_pool_id: Optional[str] = None
                         ) -> list[ToolInvokeMessage]:
         """
         Workflow invokes the tool with the given arguments.

@@ -141,6 +142,7 @@ class ToolEngine:

         if isinstance(tool, WorkflowTool):
             tool.workflow_call_depth = workflow_call_depth + 1
+            tool.thread_pool_id = thread_pool_id

         if tool.runtime and tool.runtime.runtime_parameters:
             tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
@@ -25,7 +25,6 @@ from core.tools.tool.tool import Tool
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
-from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
 from services.tools.tools_transform_service import ToolTransformService

@@ -249,7 +248,7 @@ class ToolManager:
         return tool_entity

     @classmethod
-    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
+    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
         """
         get the workflow tool runtime
         """
@@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager

 logger = logging.getLogger(__name__)

+
 class ToolFileMessageTransformer:
     @classmethod
     def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
@@ -1,116 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional

-from core.app.entities.queue_entities import AppQueueEvent
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.event import GraphEngineEvent


 class WorkflowCallback(ABC):
     @abstractmethod
-    def on_workflow_run_started(self) -> None:
+    def on_event(
+        self,
+        event: GraphEngineEvent
+    ) -> None:
         """
-        Workflow run started
+        Published event
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_run_succeeded(self) -> None:
-        """
-        Workflow run succeeded
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_run_failed(self, error: str) -> None:
-        """
-        Workflow run failed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
-        """
-        Workflow node execute started
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
-        """
-        Workflow node execute succeeded
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
-        """
-        Workflow node execute failed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
-        """
-        Publish text chunk
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_iteration_started(self,
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: Optional[dict] = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
-        """
-        Publish iteration started
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_iteration_next(self, node_id: str,
-                                   node_type: NodeType,
-                                   index: int,
-                                   node_run_index: int,
-                                   output: Optional[Any],
-                                   ) -> None:
-        """
-        Publish iteration next
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_iteration_completed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
-        """
-        Publish iteration completed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
-        """
-        raise NotImplementedError
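The callback surface collapses from roughly ten lifecycle hooks to a single typed on_event; implementations now dispatch on the concrete GraphEngineEvent subclass instead of overriding per-hook methods. A minimal hypothetical implementation, just to show the shape:

    class LoggingCallback(WorkflowCallback):
        def on_event(self, event: GraphEngineEvent) -> None:
            # Switch on type(event) here for run/node/iteration handling.
            print(type(event).__name__)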
@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
     desc: Optional[str] = None

 class BaseIterationNodeData(BaseNodeData):
-    start_node_id: str
+    start_node_id: Optional[str] = None

 class BaseIterationState(BaseModel):
     iteration_node_id: str
@@ -1,9 +1,9 @@
 from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional

 from pydantic import BaseModel

 from core.model_runtime.entities.llm_entities import LLMUsage
 from models import WorkflowNodeExecutionStatus
@@ -28,6 +28,7 @@ class NodeType(Enum):
     VARIABLE_ASSIGNER = 'variable-assigner'
     LOOP = 'loop'
     ITERATION = 'iteration'
+    ITERATION_START = 'iteration-start'  # fake start node for iteration
     PARAMETER_EXTRACTOR = 'parameter-extractor'
     CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum):
     TOOL_INFO = 'tool_info'
     ITERATION_ID = 'iteration_id'
     ITERATION_INDEX = 'iteration_index'
+    PARALLEL_ID = 'parallel_id'
+    PARALLEL_START_NODE_ID = 'parallel_start_node_id'
+    PARENT_PARALLEL_ID = 'parent_parallel_id'
+    PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'


 class NodeRunResult(BaseModel):
@@ -65,11 +70,32 @@ class NodeRunResult(BaseModel):

     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

-    inputs: Optional[Mapping[str, Any]] = None  # node inputs
-    process_data: Optional[dict] = None  # process data
-    outputs: Optional[Mapping[str, Any]] = None  # node outputs
+    inputs: Optional[dict[str, Any]] = None  # node inputs
+    process_data: Optional[dict[str, Any]] = None  # process data
+    outputs: Optional[dict[str, Any]] = None  # node outputs
     metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
     llm_usage: Optional[LLMUsage] = None  # llm usage

     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches

     error: Optional[str] = None  # error message if status is failed
+
+
+class UserFrom(Enum):
+    """
+    User from
+    """
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+    @classmethod
+    def value_of(cls, value: str) -> "UserFrom":
+        """
+        Value of
+        :param value: value
+        :return:
+        """
+        for item in cls:
+            if item.value == value:
+                return item
+        raise ValueError(f"Invalid value: {value}")
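UserFrom.value_of is a strict lookup: it maps a raw string onto a member and raises on anything unknown rather than falling back to a default. For example:

    assert UserFrom.value_of("account") is UserFrom.ACCOUNT
    assert UserFrom.value_of("end-user") is UserFrom.END_USER
    UserFrom.value_of("bot")  # raises ValueError: Invalid value: bot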
@@ -2,6 +2,7 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Any, Union

+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import deprecated

 from core.app.segments import Segment, Variable, factory
@@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"


-class VariablePool:
-    def __init__(
-        self,
-        system_variables: Mapping[SystemVariableKey, Any],
-        user_inputs: Mapping[str, Any],
-        environment_variables: Sequence[Variable],
-        conversation_variables: Sequence[Variable] | None = None,
-    ) -> None:
-        # system variables
-        # for example:
-        # {
-        #     'query': 'abc',
-        #     'files': []
-        # }
+class VariablePool(BaseModel):
+    # Variable dictionary is a dictionary for looking up variables by their selector.
+    # The first element of the selector is the node id, it's the first-level key in the dictionary.
+    # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
+    # elements of the selector except the first one.
+    variable_dictionary: dict[str, dict[int, Segment]] = Field(
+        description='Variables mapping',
+        default=defaultdict(dict)
+    )

-        # Variable dictionary is a dictionary for looking up variables by their selector.
-        # The first element of the selector is the node id, it's the first-level key in the dictionary.
-        # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
-        # elements of the selector except the first one.
-        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
+    # TODO: This user inputs is not used for pool.
+    user_inputs: Mapping[str, Any] = Field(
+        description='User inputs',
+    )

-        # TODO: This user inputs is not used for pool.
-        self.user_inputs = user_inputs
+    system_variables: Mapping[SystemVariableKey, Any] = Field(
+        description='System variables',
+    )

+    environment_variables: Sequence[Variable] = Field(
+        description="Environment variables.",
+        default_factory=list
+    )
+
+    conversation_variables: Sequence[Variable] | None = None
+
+    @model_validator(mode="after")
+    def val_model_after(self):
+        """
+        Append system variables
+        :return:
+        """
         # Add system variables to the variable pool
-        self.system_variables = system_variables
-        for key, value in system_variables.items():
+        for key, value in self.system_variables.items():
             self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

         # Add environment variables to the variable pool
-        for var in environment_variables:
+        for var in self.environment_variables or []:
             self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

         # Add conversation variables to the variable pool
-        for var in conversation_variables or []:
+        for var in self.conversation_variables or []:
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
+
+        return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
|
@ -79,7 +89,7 @@ class VariablePool:
|
|||
v = factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
|
@ -97,7 +107,7 @@ class VariablePool:
|
|||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -118,7 +128,7 @@ class VariablePool:
|
|||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
return value.to_object() if value else None
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
|
@ -134,7 +144,19 @@ class VariablePool:
|
|||
if not selector:
|
||||
return
|
||||
if len(selector) == 1:
|
||||
self._variable_dictionary[selector[0]] = {}
|
||||
self.variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
def remove_node(self, node_id: str, /):
|
||||
"""
|
||||
Remove all variables associated with a given node id.
|
||||
|
||||
Args:
|
||||
node_id (str): The node id to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.variable_dictionary.pop(node_id, None)
|
||||
|
|
|
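With VariablePool now a BaseModel, construction and lookup follow the selector convention described in the comments above (first element is the node id, the rest are hashed). A minimal usage sketch (values are illustrative; assumes the api package is importable):

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

pool = VariablePool(
    system_variables={SystemVariableKey.QUERY: "hello"},
    user_inputs={},
    environment_variables=[],
)
pool.add(("my-node", "text"), "some output")  # selector = (node_id, *keys)
segment = pool.get(("my-node", "text"))       # -> Segment wrapping "some output"
value = segment.to_object() if segment else None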
@ -66,8 +66,7 @@ class WorkflowRunState:
        self.variable_pool = variable_pool

        self.total_tokens = 0
        self.workflow_nodes_and_results = []

        self.current_iteration_state = None
        self.workflow_node_steps = 1
        self.workflow_node_runs = []
        self.workflow_node_runs = []
        self.current_iteration_state = None
@ -1,10 +1,8 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode


class WorkflowNodeRunFailedError(Exception):
    def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
        self.node_id = node_id
        self.node_type = node_type
        self.node_title = node_title
    def __init__(self, node_instance: BaseNode, error: str):
        self.node_instance = node_instance
        self.error = error
        super().__init__(f"Node {node_title} run failed: {error}")
        super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
0
api/core/workflow/graph_engine/__init__.py
Normal file
@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class RunConditionHandler(ABC):
    def __init__(self,
                 init_params: GraphInitParams,
                 graph: Graph,
                 condition: RunCondition):
        self.init_params = init_params
        self.graph = graph
        self.condition = condition

    @abstractmethod
    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState
              ) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError
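New condition types can be added by subclassing RunConditionHandler and implementing check; the handler below is a hypothetical always-true example, not part of this commit:

from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class AlwaysRunConditionHandler(RunConditionHandler):
    """Hypothetical handler that lets every edge run."""

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState) -> bool:
        return True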
@ -0,0 +1,28 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class BranchIdentifyRunConditionHandler(RunConditionHandler):

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.branch_identify:
            raise Exception("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result:
            return False

        if not run_result.edge_source_handle:
            return False

        return self.condition.branch_identify == run_result.edge_source_handle
@ -0,0 +1,32 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor


class ConditionRunConditionHandlerHandler(RunConditionHandler):
    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState
              ) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.conditions:
            return True

        # process condition
        condition_processor = ConditionProcessor()
        input_conditions, group_result = condition_processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions
        )

        # Apply the logical operator for the current case
        compare_result = all(group_result)

        return compare_result
@ -0,0 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition


class ConditionManager:
    @staticmethod
    def get_condition_handler(
            init_params: GraphInitParams,
            graph: Graph,
            run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        if run_condition.type == "branch_identify":
            return BranchIdentifyRunConditionHandler(
                init_params=init_params,
                graph=graph,
                condition=run_condition
            )
        else:
            return ConditionRunConditionHandlerHandler(
                init_params=init_params,
                graph=graph,
                condition=run_condition
            )
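Putting the handlers together: for each conditional edge the engine asks ConditionManager for a handler and calls check before following the edge. A rough sketch, where init_params, graph, state and prev_state stand in for live engine objects:

from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.run_condition import RunCondition

# "true" is a hypothetical sourceHandle emitted by an if-else node.
condition = RunCondition(type="branch_identify", branch_identify="true")
handler = ConditionManager.get_condition_handler(
    init_params=init_params,  # GraphInitParams of the running engine
    graph=graph,              # Graph built by Graph.init(...)
    run_condition=condition,
)
if handler.check(graph_runtime_state=state, previous_route_node_state=prev_state):
    ...  # follow this edge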
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
163
api/core/workflow/graph_engine/entities/event.py
Normal file
@ -0,0 +1,163 @@
from datetime import datetime
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class GraphEngineEvent(BaseModel):
    pass


###########################################
# Graph Events
###########################################


class BaseGraphEvent(GraphEngineEvent):
    pass


class GraphRunStartedEvent(BaseGraphEvent):
    pass


class GraphRunSucceededEvent(BaseGraphEvent):
    outputs: Optional[dict[str, Any]] = None
    """outputs"""


class GraphRunFailedEvent(BaseGraphEvent):
    error: str = Field(..., description="failed reason")


###########################################
# Node Events
###########################################


class BaseNodeEvent(GraphEngineEvent):
    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class NodeRunStartedEvent(BaseNodeEvent):
    predecessor_node_id: Optional[str] = None
    """predecessor node id"""


class NodeRunStreamChunkEvent(BaseNodeEvent):
    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: Optional[list[str]] = None
    """from variable selector"""


class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    retriever_resources: list[dict] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")


class NodeRunSucceededEvent(BaseNodeEvent):
    pass


class NodeRunFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


###########################################
# Parallel Branch Events
###########################################


class BaseParallelBranchEvent(GraphEngineEvent):
    parallel_id: str = Field(..., description="parallel id")
    """parallel id"""
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    """parallel start node id"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    error: str = Field(..., description="failed reason")


###########################################
# Iteration Events
###########################################


class BaseIterationEvent(GraphEngineEvent):
    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""


class IterationRunStartedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    predecessor_node_id: Optional[str] = None


class IterationRunNextEvent(BaseIterationEvent):
    index: int = Field(..., description="index")
    pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")


class IterationRunSucceededEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    steps: int = 0


class IterationRunFailedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")


InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
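GraphEngine.run (below) yields these events as a generator, so callers dispatch on type. A minimal consumer sketch, where engine stands in for a constructed GraphEngine:

from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunSucceededEvent,
    NodeRunStreamChunkEvent,
)

for event in engine.run():
    if isinstance(event, NodeRunStreamChunkEvent):
        print(event.chunk_content, end="")
    elif isinstance(event, GraphRunSucceededEvent):
        print("\noutputs:", event.outputs)
    elif isinstance(event, GraphRunFailedEvent):
        print("\nfailed:", event.error)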
692
api/core/workflow/graph_engine/entities/graph.py
Normal file
@ -0,0 +1,692 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional, cast

from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam


class GraphEdge(BaseModel):
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    run_condition: Optional[RunCondition] = None
    """run condition"""


class GraphParallel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    parent_parallel_id: Optional[str] = None
    """parent parallel id"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id"""
    end_to_node_id: Optional[str] = None
    """end to node id"""


class Graph(BaseModel):
    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=dict,
        description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict,
        description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict,
        description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
        ...,
        description="answer stream generate routes"
    )
    end_stream_param: EndStreamParam = Field(
        ...,
        description="end stream param"
    )

    @classmethod
    def init(cls,
             graph_config: Mapping[str, Any],
             root_node_id: Optional[str] = None) -> "Graph":
        """
        Init graph

        :param graph_config: graph config
        :param root_node_id: root node id
        :return: graph
        """
        # edge configs
        edge_configs = graph_config.get('edges')
        if edge_configs is None:
            edge_configs = []

        edge_configs = cast(list, edge_configs)

        # reorganize edges mapping
        edge_mapping: dict[str, list[GraphEdge]] = {}
        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
        target_edge_ids = set()
        for edge_config in edge_configs:
            source_node_id = edge_config.get('source')
            if not source_node_id:
                continue

            if source_node_id not in edge_mapping:
                edge_mapping[source_node_id] = []

            target_node_id = edge_config.get('target')
            if not target_node_id:
                continue

            if target_node_id not in reverse_edge_mapping:
                reverse_edge_mapping[target_node_id] = []

            # is target node id in source node id edge mapping
            if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
                continue

            target_edge_ids.add(target_node_id)

            # parse run condition
            run_condition = None
            if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
                run_condition = RunCondition(
                    type='branch_identify',
                    branch_identify=edge_config.get('sourceHandle')
                )

            graph_edge = GraphEdge(
                source_node_id=source_node_id,
                target_node_id=target_node_id,
                run_condition=run_condition
            )

            edge_mapping[source_node_id].append(graph_edge)
            reverse_edge_mapping[target_node_id].append(graph_edge)

        # node configs
        node_configs = graph_config.get('nodes')
        if not node_configs:
            raise ValueError("Graph must have at least one node")

        node_configs = cast(list, node_configs)

        # fetch nodes that have no predecessor node
        root_node_configs = []
        all_node_id_config_mapping: dict[str, dict] = {}
        for node_config in node_configs:
            node_id = node_config.get('id')
            if not node_id:
                continue

            if node_id not in target_edge_ids:
                root_node_configs.append(node_config)

            all_node_id_config_mapping[node_id] = node_config

        root_node_ids = [node_config.get('id') for node_config in root_node_configs]

        # fetch root node
        if not root_node_id:
            # if no root node id, use the START type node as root node
            root_node_id = next((node_config.get("id") for node_config in root_node_configs
                                 if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)

        if not root_node_id or root_node_id not in root_node_ids:
            raise ValueError(f"Root node id {root_node_id} not found in the graph")

        # Check whether it is connected to the previous node
        cls._check_connected_to_previous_node(
            route=[root_node_id],
            edge_mapping=edge_mapping
        )

        # fetch all node ids from root node
        node_ids = [root_node_id]
        cls._recursively_add_node_ids(
            node_ids=node_ids,
            edge_mapping=edge_mapping,
            node_id=root_node_id
        )

        node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}

        # init parallel mapping
        parallel_mapping: dict[str, GraphParallel] = {}
        node_parallel_mapping: dict[str, str] = {}
        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=root_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping
        )

        # Check if it exceeds N layers of parallel
        for parallel in parallel_mapping.values():
            if parallel.parent_parallel_id:
                cls._check_exceed_parallel_limit(
                    parallel_mapping=parallel_mapping,
                    level_limit=3,
                    parent_parallel_id=parallel.parent_parallel_id
                )

        # init answer stream generate routes
        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping
        )

        # init end stream param
        end_stream_param = EndStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            node_parallel_mapping=node_parallel_mapping
        )

        # init graph
        graph = cls(
            root_node_id=root_node_id,
            node_ids=node_ids,
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            answer_stream_generate_routes=answer_stream_generate_routes,
            end_stream_param=end_stream_param
        )

        return graph

    def add_extra_edge(self, source_node_id: str,
                       target_node_id: str,
                       run_condition: Optional[RunCondition] = None) -> None:
        """
        Add extra edge to the graph

        :param source_node_id: source node id
        :param target_node_id: target node id
        :param run_condition: run condition
        """
        if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
            return

        if source_node_id not in self.edge_mapping:
            self.edge_mapping[source_node_id] = []

        if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
            return

        graph_edge = GraphEdge(
            source_node_id=source_node_id,
            target_node_id=target_node_id,
            run_condition=run_condition
        )

        self.edge_mapping[source_node_id].append(graph_edge)

    def get_leaf_node_ids(self) -> list[str]:
        """
        Get leaf node ids of the graph

        :return: leaf node ids
        """
        leaf_node_ids = []
        for node_id in self.node_ids:
            if node_id not in self.edge_mapping:
                leaf_node_ids.append(node_id)
            elif (len(self.edge_mapping[node_id]) == 1
                  and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
                leaf_node_ids.append(node_id)

        return leaf_node_ids

    @classmethod
    def _recursively_add_node_ids(cls,
                                  node_ids: list[str],
                                  edge_mapping: dict[str, list[GraphEdge]],
                                  node_id: str) -> None:
        """
        Recursively add node ids

        :param node_ids: node ids
        :param edge_mapping: edge mapping
        :param node_id: node id
        """
        for graph_edge in edge_mapping.get(node_id, []):
            if graph_edge.target_node_id in node_ids:
                continue

            node_ids.append(graph_edge.target_node_id)
            cls._recursively_add_node_ids(
                node_ids=node_ids,
                edge_mapping=edge_mapping,
                node_id=graph_edge.target_node_id
            )

    @classmethod
    def _check_connected_to_previous_node(
            cls,
            route: list[str],
            edge_mapping: dict[str, list[GraphEdge]]
    ) -> None:
        """
        Check whether it is connected to the previous node
        """
        last_node_id = route[-1]

        for graph_edge in edge_mapping.get(last_node_id, []):
            if not graph_edge.target_node_id:
                continue

            if graph_edge.target_node_id in route:
                raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")

            new_route = route[:]
            new_route.append(graph_edge.target_node_id)
            cls._check_connected_to_previous_node(
                route=new_route,
                edge_mapping=edge_mapping,
            )

    @classmethod
    def _recursively_add_parallels(
            cls,
            edge_mapping: dict[str, list[GraphEdge]],
            reverse_edge_mapping: dict[str, list[GraphEdge]],
            start_node_id: str,
            parallel_mapping: dict[str, GraphParallel],
            node_parallel_mapping: dict[str, str],
            parent_parallel: Optional[GraphParallel] = None
    ) -> None:
        """
        Recursively add parallel ids

        :param edge_mapping: edge mapping
        :param start_node_id: start from node id
        :param parallel_mapping: parallel mapping
        :param node_parallel_mapping: node parallel mapping
        :param parent_parallel: parent parallel
        """
        target_node_edges = edge_mapping.get(start_node_id, [])
        parallel = None
        if len(target_node_edges) > 1:
            # fetch all node ids in current parallels
            parallel_branch_node_ids = []
            condition_edge_mappings = {}
            for graph_edge in target_node_edges:
                if graph_edge.run_condition is None:
                    parallel_branch_node_ids.append(graph_edge.target_node_id)
                else:
                    condition_hash = graph_edge.run_condition.hash
                    if condition_hash not in condition_edge_mappings:
                        condition_edge_mappings[condition_hash] = []

                    condition_edge_mappings[condition_hash].append(graph_edge)

            for _, graph_edges in condition_edge_mappings.items():
                if len(graph_edges) > 1:
                    for graph_edge in graph_edges:
                        parallel_branch_node_ids.append(graph_edge.target_node_id)

            # any target node id in node_parallel_mapping
            if parallel_branch_node_ids:
                parent_parallel_id = parent_parallel.id if parent_parallel else None

                parallel = GraphParallel(
                    start_from_node_id=start_node_id,
                    parent_parallel_id=parent_parallel.id if parent_parallel else None,
                    parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
                )
                parallel_mapping[parallel.id] = parallel

                in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                    edge_mapping=edge_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    parallel_branch_node_ids=parallel_branch_node_ids
                )

                # collect all branches node ids
                parallel_node_ids = []
                for _, node_ids in in_branch_node_ids.items():
                    for node_id in node_ids:
                        in_parent_parallel = True
                        if parent_parallel_id:
                            in_parent_parallel = False
                            for parallel_node_id, parallel_id in node_parallel_mapping.items():
                                if parallel_id == parent_parallel_id and parallel_node_id == node_id:
                                    in_parent_parallel = True
                                    break

                        if in_parent_parallel:
                            parallel_node_ids.append(node_id)
                            node_parallel_mapping[node_id] = parallel.id

                outside_parallel_target_node_ids = set()
                for node_id in parallel_node_ids:
                    if node_id == parallel.start_from_node_id:
                        continue

                    node_edges = edge_mapping.get(node_id)
                    if not node_edges:
                        continue

                    if len(node_edges) > 1:
                        continue

                    target_node_id = node_edges[0].target_node_id
                    if target_node_id in parallel_node_ids:
                        continue

                    if parent_parallel_id:
                        parent_parallel = parallel_mapping.get(parent_parallel_id)
                        if not parent_parallel:
                            continue

                    if (
                        (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
                        or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
                        or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                    ):
                        outside_parallel_target_node_ids.add(target_node_id)

                if len(outside_parallel_target_node_ids) == 1:
                    if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
                        parallel.end_to_node_id = None
                    else:
                        parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

        for graph_edge in target_node_edges:
            current_parallel = None
            if parallel:
                current_parallel = parallel
            elif parent_parallel:
                if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
                    current_parallel = parent_parallel
                else:
                    # fetch parent parallel's parent parallel
                    parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                    if parent_parallel_parent_parallel_id:
                        parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
                        if (
                            parent_parallel_parent_parallel
                            and (
                                not parent_parallel_parent_parallel.end_to_node_id
                                or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
                            )
                        ):
                            current_parallel = parent_parallel_parent_parallel

            cls._recursively_add_parallels(
                edge_mapping=edge_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                start_node_id=graph_edge.target_node_id,
                parallel_mapping=parallel_mapping,
                node_parallel_mapping=node_parallel_mapping,
                parent_parallel=current_parallel
            )

    @classmethod
    def _check_exceed_parallel_limit(
            cls,
            parallel_mapping: dict[str, GraphParallel],
            level_limit: int,
            parent_parallel_id: str,
            current_level: int = 1
    ) -> None:
        """
        Check if it exceeds N layers of parallel
        """
        parent_parallel = parallel_mapping.get(parent_parallel_id)
        if not parent_parallel:
            return

        current_level += 1
        if current_level > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        if parent_parallel.parent_parallel_id:
            cls._check_exceed_parallel_limit(
                parallel_mapping=parallel_mapping,
                level_limit=level_limit,
                parent_parallel_id=parent_parallel.parent_parallel_id,
                current_level=current_level
            )

    @classmethod
    def _recursively_add_parallel_node_ids(cls,
                                           branch_node_ids: list[str],
                                           edge_mapping: dict[str, list[GraphEdge]],
                                           merge_node_id: str,
                                           start_node_id: str) -> None:
        """
        Recursively add node ids

        :param branch_node_ids: in branch node ids
        :param edge_mapping: edge mapping
        :param merge_node_id: merge node id
        :param start_node_id: start node id
        """
        for graph_edge in edge_mapping.get(start_node_id, []):
            if (graph_edge.target_node_id != merge_node_id
                    and graph_edge.target_node_id not in branch_node_ids):
                branch_node_ids.append(graph_edge.target_node_id)
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=branch_node_ids,
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=graph_edge.target_node_id
                )

    @classmethod
    def _fetch_all_node_ids_in_parallels(cls,
                                         edge_mapping: dict[str, list[GraphEdge]],
                                         reverse_edge_mapping: dict[str, list[GraphEdge]],
                                         parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
        """
        Fetch all node ids in parallels
        """
        routes_node_ids: dict[str, list[str]] = {}
        for parallel_branch_node_id in parallel_branch_node_ids:
            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

            # fetch routes node ids
            cls._recursively_fetch_routes(
                edge_mapping=edge_mapping,
                start_node_id=parallel_branch_node_id,
                routes_node_ids=routes_node_ids[parallel_branch_node_id]
            )

        # fetch leaf node ids from routes node ids
        leaf_node_ids: dict[str, list[str]] = {}
        merge_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
                    if branch_node_id not in leaf_node_ids:
                        leaf_node_ids[branch_node_id] = []

                    leaf_node_ids[branch_node_id].append(node_id)

                for branch_node_id2, inner_route2 in routes_node_ids.items():
                    if (
                        branch_node_id != branch_node_id2
                        and node_id in inner_route2
                        and len(reverse_edge_mapping.get(node_id, [])) > 1
                        and cls._is_node_in_routes(
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=node_id,
                            routes_node_ids=routes_node_ids
                        )
                    ):
                        if node_id not in merge_branch_node_ids:
                            merge_branch_node_ids[node_id] = []

                        if branch_node_id2 not in merge_branch_node_ids[node_id]:
                            merge_branch_node_ids[node_id].append(branch_node_id2)

        # sorted merge_branch_node_ids by branch_node_ids length desc
        merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

        duplicate_end_node_ids = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                    if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
            # check which node is after
            if cls._is_node2_after_node1(
                node1_id=node_id,
                node2_id=node_id2,
                edge_mapping=edge_mapping
            ):
                if node_id in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id2]
            elif cls._is_node2_after_node1(
                node1_id=node_id2,
                node2_id=node_id,
                edge_mapping=edge_mapping
            ):
                if node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id]

        branches_merge_node_ids: dict[str, str] = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            if len(branch_node_ids) <= 1:
                continue

            for branch_node_id in branch_node_ids:
                if branch_node_id in branches_merge_node_ids:
                    continue

                branches_merge_node_ids[branch_node_id] = node_id

        in_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            in_branch_node_ids[branch_node_id] = []
            if branch_node_id not in branches_merge_node_ids:
                # all node ids in current branch is in this thread
                in_branch_node_ids[branch_node_id].append(branch_node_id)
                in_branch_node_ids[branch_node_id].extend(node_ids)
            else:
                merge_node_id = branches_merge_node_ids[branch_node_id]
                if merge_node_id != branch_node_id:
                    in_branch_node_ids[branch_node_id].append(branch_node_id)

                # fetch all node ids from branch_node_id and merge_node_id
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=in_branch_node_ids[branch_node_id],
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=branch_node_id
                )

        return in_branch_node_ids

    @classmethod
    def _recursively_fetch_routes(cls,
                                  edge_mapping: dict[str, list[GraphEdge]],
                                  start_node_id: str,
                                  routes_node_ids: list[str]) -> None:
        """
        Recursively fetch route
        """
        if start_node_id not in edge_mapping:
            return

        for graph_edge in edge_mapping[start_node_id]:
            # find next node ids
            if graph_edge.target_node_id not in routes_node_ids:
                routes_node_ids.append(graph_edge.target_node_id)

                cls._recursively_fetch_routes(
                    edge_mapping=edge_mapping,
                    start_node_id=graph_edge.target_node_id,
                    routes_node_ids=routes_node_ids
                )

    @classmethod
    def _is_node_in_routes(cls,
                           reverse_edge_mapping: dict[str, list[GraphEdge]],
                           start_node_id: str,
                           routes_node_ids: dict[str, list[str]]) -> bool:
        """
        Recursively check if the node is in the routes
        """
        if start_node_id not in reverse_edge_mapping:
            return False

        all_routes_node_ids = set()
        parallel_start_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                all_routes_node_ids.add(node_id)

            if branch_node_id in reverse_edge_mapping:
                for graph_edge in reverse_edge_mapping[branch_node_id]:
                    if graph_edge.source_node_id not in parallel_start_node_ids:
                        parallel_start_node_ids[graph_edge.source_node_id] = []

                    parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)

        parallel_start_node_id = None
        for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
            if set(branch_node_ids) == set(routes_node_ids.keys()):
                parallel_start_node_id = p_start_node_id
                return True

        if not parallel_start_node_id:
            raise Exception("Parallel start node id not found")

        for graph_edge in reverse_edge_mapping[start_node_id]:
            if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
                return False

        return True

    @classmethod
    def _is_node2_after_node1(
            cls,
            node1_id: str,
            node2_id: str,
            edge_mapping: dict[str, list[GraphEdge]]
    ) -> bool:
        """
        is node2 after node1
        """
        if node1_id not in edge_mapping:
            return False

        for graph_edge in edge_mapping[node1_id]:
            if graph_edge.target_node_id == node2_id:
                return True

            if cls._is_node2_after_node1(
                node1_id=graph_edge.target_node_id,
                node2_id=node2_id,
                edge_mapping=edge_mapping
            ):
                return True

        return False
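The structural helpers at the bottom of Graph reduce to plain reachability walks over edge_mapping. A self-contained sketch of the idea behind _is_node2_after_node1, stripped of the pydantic models (names are illustrative, and like the original it assumes the graph is acyclic, which _check_connected_to_previous_node enforces):

def is_after(node1: str, node2: str, edges: dict[str, list[str]]) -> bool:
    # Depth-first walk from node1; True if node2 is reachable.
    for target in edges.get(node1, []):
        if target == node2 or is_after(target, node2, edges):
            return True
    return False


edges = {"start": ["a", "b"], "a": ["merge"], "b": ["merge"], "merge": ["end"]}
assert is_after("start", "merge", edges)
assert not is_after("merge", "start", edges)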
21
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType


class GraphInitParams(BaseModel):
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
    user_from: UserFrom = Field(..., description="user from, account or end-user")
    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
    call_depth: int = Field(..., description="call depth")
@ -0,0 +1,27 @@
from typing import Any

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState


class GraphRuntimeState(BaseModel):
    variable_pool: VariablePool = Field(..., description="variable pool")
    """variable pool"""

    start_at: float = Field(..., description="start time")
    """start time"""
    total_tokens: int = 0
    """total tokens"""
    llm_usage: LLMUsage = LLMUsage.empty_usage()
    """llm usage info"""
    outputs: dict[str, Any] = {}
    """outputs"""

    node_run_steps: int = 0
    """node run steps"""

    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""
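Note that start_at is a time.perf_counter float rather than a datetime, matching how GraphEngine computes elapsed time. A minimal construction sketch, where pool stands in for a VariablePool built as sketched earlier:

import time

from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

state = GraphRuntimeState(
    variable_pool=pool,           # a VariablePool instance
    start_at=time.perf_counter()  # elapsed-time clock, not wall-clock
)
state.total_tokens += 42          # accumulated by nodes as they run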
13
api/core/workflow/graph_engine/entities/next_graph_node.py
Normal file
@ -0,0 +1,13 @@
from typing import Optional

from pydantic import BaseModel

from core.workflow.graph_engine.entities.graph import GraphParallel


class NextGraphNode(BaseModel):
    node_id: str
    """next node id"""

    parallel: Optional[GraphParallel] = None
    """parallel"""
21
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
@ -0,0 +1,21 @@
import hashlib
from typing import Literal, Optional

from pydantic import BaseModel

from core.workflow.utils.condition.entities import Condition


class RunCondition(BaseModel):
    type: Literal["branch_identify", "condition"]
    """condition type"""

    branch_identify: Optional[str] = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

    conditions: Optional[list[Condition]] = None
    """conditions to run the node, required when type is condition"""

    @property
    def hash(self) -> str:
        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
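Graph.init uses this hash to group edges that carry an identical condition, so several same-condition edges leaving one node form a single parallel. A self-contained sketch of the trick with a stand-in model (not the commit's Condition type):

import hashlib

from pydantic import BaseModel


class MiniCondition(BaseModel):
    """Stand-in for RunCondition, just to show the hashing idea."""
    branch_identify: str

    @property
    def hash(self) -> str:
        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()


assert MiniCondition(branch_identify="true").hash == MiniCondition(branch_identify="true").hash
assert MiniCondition(branch_identify="true").hash != MiniCondition(branch_identify="false").hash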
111
api/core/workflow/graph_engine/entities/runtime_route_state.py
Normal file
@ -0,0 +1,111 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus


class RouteNodeState(BaseModel):
    class Status(Enum):
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """node state id"""

    node_id: str
    """node id"""

    node_run_result: Optional[NodeRunResult] = None
    """node run result"""

    status: Status = Status.RUNNING
    """node status"""

    start_at: datetime
    """start time"""

    paused_at: Optional[datetime] = None
    """paused time"""

    finished_at: Optional[datetime] = None
    """finished time"""

    failed_reason: Optional[str] = None
    """failed reason"""

    paused_by: Optional[str] = None
    """paused by"""

    index: int = 1

    def set_finished(self, run_result: NodeRunResult) -> None:
        """
        Node finished

        :param run_result: run result
        """
        if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
            raise Exception(f"Route state {self.id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)


class RuntimeRouteState(BaseModel):
    routes: dict[str, list[str]] = Field(
        default_factory=dict,
        description="graph state routes (source_node_state_id: target_node_state_id)"
    )

    node_state_mapping: dict[str, RouteNodeState] = Field(
        default_factory=dict,
        description="node state mapping (route_node_state_id: route_node_state)"
    )

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """
        Create node state

        :param node_id: node id
        """
        state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
        self.node_state_mapping[state.id] = state
        return state

    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
        """
        Add route to the graph state

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        if source_node_state_id not in self.routes:
            self.routes[source_node_state_id] = []

        self.routes[source_node_state_id].append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
            -> list[RouteNodeState]:
        """
        Get routes with node state by source node id

        :param source_node_state_id: source node state id
        :return: routes with node state
        """
        return [self.node_state_mapping[target_state_id]
                for target_state_id in self.routes.get(source_node_state_id, [])]
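These route states form the execution trace the engine appends to as it walks the graph. A minimal lifecycle sketch (assumes the api package is importable):

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from models.workflow import WorkflowNodeExecutionStatus

run_state = RuntimeRouteState()
start = run_state.create_node_state(node_id="start")
llm = run_state.create_node_state(node_id="llm")
run_state.add_route(source_node_state_id=start.id, target_node_state_id=llm.id)

llm.set_finished(NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
successors = run_state.get_routes_with_node_state_by_source_node_state_id(start.id)  # [llm]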
716
api/core/workflow/graph_engine/graph_engine.py
Normal file
@ -0,0 +1,716 @@
|
|||
import logging
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeType,
|
||||
UserFrom,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseIterationEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
def __init__(self, max_workers=None, thread_name_prefix='',
|
||||
initializer=None, initargs=(), max_submit_count=100) -> None:
|
||||
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
self.submit_count += 1
|
||||
self.check_is_full()
|
||||
|
||||
return super().submit(fn, *args, **kwargs)
|
||||
|
||||
def check_is_full(self) -> None:
|
||||
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
|
||||
if self.submit_count > self.max_submit_count:
|
||||
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
thread_pool_max_submit_count = 100
|
||||
thread_pool_max_workers = 10
|
||||
|
||||
## init thread pool
|
||||
if thread_pool_id:
|
||||
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
|
||||
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
|
||||
self.is_main_thread_pool = False
|
||||
else:
|
||||
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
|
||||
self.thread_pool_id = str(uuid.uuid4())
|
||||
self.is_main_thread_pool = True
|
||||
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
|
||||
|
||||
self.graph = graph
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self.max_execution_steps = max_execution_steps
|
||||
self.max_execution_time = max_execution_time
|
||||
|
||||
    def run(self) -> Generator[GraphEngineEvent, None, None]:
        # trigger graph run start event
        yield GraphRunStartedEvent()

        try:
            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
            if self.init_params.workflow_type == WorkflowType.CHAT:
                stream_processor_cls = AnswerStreamProcessor
            else:
                stream_processor_cls = EndStreamProcessor

            stream_processor = stream_processor_cls(
                graph=self.graph,
                variable_pool=self.graph_runtime_state.variable_pool
            )

            # run graph
            generator = stream_processor.process(
                self._run(start_node_id=self.graph.root_node_id)
            )

            for item in generator:
                try:
                    yield item
                    if isinstance(item, NodeRunFailedEvent):
                        yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
                        return
                    elif isinstance(item, NodeRunSucceededEvent):
                        if item.node_type == NodeType.END:
                            self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
                                                                if item.route_node_state.node_run_result
                                                                and item.route_node_state.node_run_result.outputs
                                                                else {})
                        elif item.node_type == NodeType.ANSWER:
                            if "answer" not in self.graph_runtime_state.outputs:
                                self.graph_runtime_state.outputs["answer"] = ""

                            self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
                                                                                  if item.route_node_state.node_run_result
                                                                                  and item.route_node_state.node_run_result.outputs
                                                                                  else "")

                            self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
                except Exception as e:
                    logger.exception(f"Graph run failed: {str(e)}")
                    yield GraphRunFailedEvent(error=str(e))
                    return

            # trigger graph run success event
            yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
        except GraphRunFailedError as e:
            yield GraphRunFailedEvent(error=e.error)
            return
        except Exception as e:
            logger.exception("Unknown error when running graph")
            yield GraphRunFailedEvent(error=str(e))
            raise e
        finally:
            if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
                del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
    def _run(
        self,
        start_node_id: str,
        in_parallel_id: Optional[str] = None,
        parent_parallel_id: Optional[str] = None,
        parent_parallel_start_node_id: Optional[str] = None
    ) -> Generator[GraphEngineEvent, None, None]:
        parallel_start_node_id = None
        if in_parallel_id:
            parallel_start_node_id = start_node_id

        next_node_id = start_node_id
        previous_route_node_state: Optional[RouteNodeState] = None
        while True:
            # max steps reached
            if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
                raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))

            # or max execution time reached
            if self._is_timed_out(
                start_at=self.graph_runtime_state.start_at,
                max_execution_time=self.max_execution_time
            ):
                raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

            # init route node state
            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
                node_id=next_node_id
            )

            # get node config
            node_id = route_node_state.node_id
            node_config = self.graph.node_id_config_mapping.get(node_id)
            if not node_config:
                raise GraphRunFailedError(f'Node {node_id} config not found.')

            # convert to specific node
            node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
            node_cls = node_classes.get(node_type)
            if not node_cls:
                raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')

            previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

            # init workflow run state
            node_instance = node_cls(  # type: ignore
                id=route_node_state.id,
                config=node_config,
                graph_init_params=self.init_params,
                graph=self.graph,
                graph_runtime_state=self.graph_runtime_state,
                previous_node_id=previous_node_id,
                thread_pool_id=self.thread_pool_id
            )

            try:
                # run node
                generator = self._run_node(
                    node_instance=node_instance,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )

                for item in generator:
                    if isinstance(item, NodeRunStartedEvent):
                        self.graph_runtime_state.node_run_steps += 1
                        item.route_node_state.index = self.graph_runtime_state.node_run_steps

                    yield item

                self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state

                # append route
                if previous_route_node_state:
                    self.graph_runtime_state.node_run_state.add_route(
                        source_node_state_id=previous_route_node_state.id,
                        target_node_state_id=route_node_state.id
                    )
            except Exception as e:
                route_node_state.status = RouteNodeState.Status.FAILED
                route_node_state.failed_reason = str(e)
                yield NodeRunFailedEvent(
                    error=str(e),
                    id=node_instance.id,
                    node_id=next_node_id,
                    node_type=node_type,
                    node_data=node_instance.node_data,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )
                raise e

            # stop traversal once an END node has been executed
            if (self.graph.node_id_config_mapping[next_node_id]
                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
                break

            previous_route_node_state = route_node_state

            # get next node ids
            edge_mappings = self.graph.edge_mapping.get(next_node_id)
            if not edge_mappings:
                break

            if len(edge_mappings) == 1:
                edge = edge_mappings[0]

                if edge.run_condition:
                    result = ConditionManager.get_condition_handler(
                        init_params=self.init_params,
                        graph=self.graph,
                        run_condition=edge.run_condition,
                    ).check(
                        graph_runtime_state=self.graph_runtime_state,
                        previous_route_node_state=previous_route_node_state
                    )

                    if not result:
                        break

                next_node_id = edge.target_node_id
            else:
                final_node_id = None

                if any(edge.run_condition for edge in edge_mappings):
                    # if the edges have run conditions, pick which branch to take
                    # based on the run condition results
                    condition_edge_mappings = {}
                    for edge in edge_mappings:
                        if edge.run_condition:
                            run_condition_hash = edge.run_condition.hash
                            if run_condition_hash not in condition_edge_mappings:
                                condition_edge_mappings[run_condition_hash] = []

                            condition_edge_mappings[run_condition_hash].append(edge)

                    for _, sub_edge_mappings in condition_edge_mappings.items():
                        if len(sub_edge_mappings) == 0:
                            continue

                        edge = sub_edge_mappings[0]

                        result = ConditionManager.get_condition_handler(
                            init_params=self.init_params,
                            graph=self.graph,
                            run_condition=edge.run_condition,
                        ).check(
                            graph_runtime_state=self.graph_runtime_state,
                            previous_route_node_state=previous_route_node_state,
                        )

                        if not result:
                            continue

                        if len(sub_edge_mappings) == 1:
                            final_node_id = edge.target_node_id
                        else:
                            parallel_generator = self._run_parallel_branches(
                                edge_mappings=sub_edge_mappings,
                                in_parallel_id=in_parallel_id,
                                parallel_start_node_id=parallel_start_node_id
                            )

                            for item in parallel_generator:
                                if isinstance(item, str):
                                    final_node_id = item
                                else:
                                    yield item

                        break

                    if not final_node_id:
                        break

                    next_node_id = final_node_id
                else:
                    parallel_generator = self._run_parallel_branches(
                        edge_mappings=edge_mappings,
                        in_parallel_id=in_parallel_id,
                        parallel_start_node_id=parallel_start_node_id
                    )

                    for item in parallel_generator:
                        if isinstance(item, str):
                            final_node_id = item
                        else:
                            yield item

                    if not final_node_id:
                        break

                    next_node_id = final_node_id

            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
                break
    def _run_parallel_branches(
        self,
        edge_mappings: list[GraphEdge],
        in_parallel_id: Optional[str] = None,
        parallel_start_node_id: Optional[str] = None,
    ) -> Generator[GraphEngineEvent | str, None, None]:
        # if the edges have no run conditions, run all target nodes in parallel
        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
        if not parallel_id:
            node_id = edge_mappings[0].target_node_id
            node_config = self.graph.node_id_config_mapping.get(node_id)
            if not node_config:
                raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')

            node_title = node_config.get('data', {}).get('title')
            raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')

        parallel = self.graph.parallel_mapping.get(parallel_id)
        if not parallel:
            raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

        # run parallel nodes in new threads and collect their events through a queue
        q: queue.Queue = queue.Queue()

        # Create a list to store the futures
        futures = []

        # submit one branch per outgoing edge of this parallel
        for edge in edge_mappings:
            if (
                edge.target_node_id not in self.graph.node_parallel_mapping
                or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
            ):
                continue

            futures.append(
                self.thread_pool.submit(self._run_parallel_node, **{
                    'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
                    'q': q,
                    'parallel_id': parallel_id,
                    'parallel_start_node_id': edge.target_node_id,
                    'parent_parallel_id': in_parallel_id,
                    'parent_parallel_start_node_id': parallel_start_node_id,
                })
            )

        succeeded_count = 0
        while True:
            try:
                event = q.get(timeout=1)
                if event is None:
                    break

                yield event
                if event.parallel_id == parallel_id:
                    if isinstance(event, ParallelBranchRunSucceededEvent):
                        succeeded_count += 1
                        if succeeded_count == len(futures):
                            q.put(None)

                        continue
                    elif isinstance(event, ParallelBranchRunFailedEvent):
                        raise GraphRunFailedError(event.error)
            except queue.Empty:
                continue

        # wait for all threads to finish
        wait(futures)

        # get final node id
        final_node_id = parallel.end_to_node_id
        if final_node_id:
            yield final_node_id
    def _run_parallel_node(
        self,
        flask_app: Flask,
        q: queue.Queue,
        parallel_id: str,
        parallel_start_node_id: str,
        parent_parallel_id: Optional[str] = None,
        parent_parallel_start_node_id: Optional[str] = None,
    ) -> None:
        """
        Run a single parallel branch in its own thread.
        """
        with flask_app.app_context():
            try:
                q.put(ParallelBranchRunStartedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                ))

                # run node
                generator = self._run(
                    start_node_id=parallel_start_node_id,
                    in_parallel_id=parallel_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )

                for item in generator:
                    q.put(item)

                # trigger parallel branch run success event
                q.put(ParallelBranchRunSucceededEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                ))
            except GraphRunFailedError as e:
                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    error=e.error
                ))
            except Exception as e:
                logger.exception("Unknown error when generating in parallel")
                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    error=str(e)
                ))
            finally:
                db.session.remove()
    def _run_node(
        self,
        node_instance: BaseNode,
        route_node_state: RouteNodeState,
        parallel_id: Optional[str] = None,
        parallel_start_node_id: Optional[str] = None,
        parent_parallel_id: Optional[str] = None,
        parent_parallel_start_node_id: Optional[str] = None,
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Run a single node and translate its events into graph engine events.
        """
        # trigger node run start event
        yield NodeRunStartedEvent(
            id=node_instance.id,
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_data=node_instance.node_data,
            route_node_state=route_node_state,
            predecessor_node_id=node_instance.previous_node_id,
            parallel_id=parallel_id,
            parallel_start_node_id=parallel_start_node_id,
            parent_parallel_id=parent_parallel_id,
            parent_parallel_start_node_id=parent_parallel_start_node_id
        )

        db.session.close()

        try:
            # run node
            generator = node_instance.run()
            for item in generator:
                if isinstance(item, GraphEngineEvent):
                    if isinstance(item, BaseIterationEvent):
                        # add parallel info to iteration event
                        item.parallel_id = parallel_id
                        item.parallel_start_node_id = parallel_start_node_id
                        item.parent_parallel_id = parent_parallel_id
                        item.parent_parallel_start_node_id = parent_parallel_start_node_id

                    yield item
                else:
                    if isinstance(item, RunCompletedEvent):
                        run_result = item.run_result
                        route_node_state.set_finished(run_result=run_result)

                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                            yield NodeRunFailedEvent(
                                error=route_node_state.failed_reason or 'Unknown error.',
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                route_node_state=route_node_state,
                                parallel_id=parallel_id,
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id
                            )
                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                                # accumulate total tokens into the runtime state
                                self.graph_runtime_state.total_tokens += int(
                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                                )

                            if run_result.llm_usage:
                                # use the latest usage
                                self.graph_runtime_state.llm_usage += run_result.llm_usage

                            # append node output variables to variable pool
                            if run_result.outputs:
                                for variable_key, variable_value in run_result.outputs.items():
                                    # append variables to variable pool recursively
                                    self._append_variables_recursively(
                                        node_id=node_instance.node_id,
                                        variable_key_list=[variable_key],
                                        variable_value=variable_value
                                    )

                            # add parallel info to run result metadata
                            if parallel_id and parallel_start_node_id:
                                if not run_result.metadata:
                                    run_result.metadata = {}

                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
                                if parent_parallel_id and parent_parallel_start_node_id:
                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id

                            yield NodeRunSucceededEvent(
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                route_node_state=route_node_state,
                                parallel_id=parallel_id,
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id
                            )

                        break
                    elif isinstance(item, RunStreamChunkEvent):
                        yield NodeRunStreamChunkEvent(
                            id=node_instance.id,
                            node_id=node_instance.node_id,
                            node_type=node_instance.node_type,
                            node_data=node_instance.node_data,
                            chunk_content=item.chunk_content,
                            from_variable_selector=item.from_variable_selector,
                            route_node_state=route_node_state,
                            parallel_id=parallel_id,
                            parallel_start_node_id=parallel_start_node_id,
                            parent_parallel_id=parent_parallel_id,
                            parent_parallel_start_node_id=parent_parallel_start_node_id
                        )
                    elif isinstance(item, RunRetrieverResourceEvent):
                        yield NodeRunRetrieverResourceEvent(
                            id=node_instance.id,
                            node_id=node_instance.node_id,
                            node_type=node_instance.node_type,
                            node_data=node_instance.node_data,
                            retriever_resources=item.retriever_resources,
                            context=item.context,
                            route_node_state=route_node_state,
                            parallel_id=parallel_id,
                            parallel_start_node_id=parallel_start_node_id,
                            parent_parallel_id=parent_parallel_id,
                            parent_parallel_start_node_id=parent_parallel_start_node_id
                        )
        except GenerateTaskStoppedException:
            # trigger node run failed event
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = "Workflow stopped."
            yield NodeRunFailedEvent(
                error="Workflow stopped.",
                id=node_instance.id,
                node_id=node_instance.node_id,
                node_type=node_instance.node_type,
                node_data=node_instance.node_data,
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id
            )
            return
        except Exception as e:
            logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
            raise e
        finally:
            db.session.close()
    def _append_variables_recursively(self,
                                      node_id: str,
                                      variable_key_list: list[str],
                                      variable_value: VariableValue):
        """
        Append variables recursively
        :param node_id: node id
        :param variable_key_list: variable key list
        :param variable_value: variable value
        :return:
        """
        self.graph_runtime_state.variable_pool.add(
            [node_id] + variable_key_list,
            variable_value
        )

        # if variable_value is a dict, then recursively append variables
        if isinstance(variable_value, dict):
            for key, value in variable_value.items():
                # construct new key list
                new_key_list = variable_key_list + [key]
                self._append_variables_recursively(
                    node_id=node_id,
                    variable_key_list=new_key_list,
                    variable_value=value
                )

    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
        """
        Check timeout
        :param start_at: start time
        :param max_execution_time: max execution time
        :return:
        """
        return time.perf_counter() - start_at > max_execution_time


class GraphRunFailedError(Exception):
    def __init__(self, error: str):
        self.error = error
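Because run() is a generator, callers drive the engine by draining events and reacting to the terminal ones. A minimal driver sketch, not part of the diff: the graph, variable pool, and identifier values are assumed to be prepared elsewhere, and the step/time limits shown are example values rather than defaults.

    # Hypothetical caller of the engine above; all argument values are placeholders.
    engine = GraphEngine(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_type=WorkflowType.CHAT,
        workflow_id=workflow_id,
        user_id=user_id,
        user_from=user_from,
        invoke_from=invoke_from,
        call_depth=0,
        graph=graph,
        graph_config=graph_config,
        variable_pool=variable_pool,
        max_execution_steps=500,
        max_execution_time=1200,
    )

    outputs = None
    for event in engine.run():
        # stream chunks, node events, etc. pass through here as well
        if isinstance(event, GraphRunSucceededEvent):
            outputs = event.outputs
        elif isinstance(event, GraphRunFailedEvent):
            raise RuntimeError(event.error)

Note that a nested run (e.g. a parallel branch) reuses the parent's pool via thread_pool_id, which is why the mapping is only cleaned up by the main pool's owner in the finally block above.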
@@ -1,9 +1,8 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast

-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
 from core.workflow.nodes.answer.entities import (
     AnswerNodeData,
     GenerateRouteChunk,
@@ -19,24 +18,26 @@ class AnswerNode(BaseNode):
     _node_data_cls = AnswerNodeData
     _node_type: NodeType = NodeType.ANSWER

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
         node_data = cast(AnswerNodeData, node_data)

         # generate routes
-        generate_routes = self.extract_generate_route_from_node_data(node_data)
+        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)

         answer = ''
         for part in generate_routes:
-            if part.type == "var":
+            if part.type == GenerateRouteChunk.ChunkType.VAR:
                 part = cast(VarGenerateRouteChunk, part)
                 value_selector = part.value_selector
-                value = variable_pool.get(value_selector)
+                value = self.graph_runtime_state.variable_pool.get(
+                    value_selector
+                )

                 if value:
                     answer += value.markdown
             else:
@@ -51,70 +52,16 @@ class AnswerNode(BaseNode):
         )

-    @classmethod
-    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route selectors
-        :param config: node config
-        :return:
-        """
-        node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(AnswerNodeData, node_data)
-
-        return cls.extract_generate_route_from_node_data(node_data)
-
-    @classmethod
-    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route from node data
-        :param node_data: node data object
-        :return:
-        """
-        variable_template_parser = VariableTemplateParser(template=node_data.answer)
-        variable_selectors = variable_template_parser.extract_variable_selectors()
-
-        value_selector_mapping = {
-            variable_selector.variable: variable_selector.value_selector
-            for variable_selector in variable_selectors
-        }
-
-        variable_keys = list(value_selector_mapping.keys())
-
-        # format answer template
-        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
-        template_variable_keys = template_parser.variable_keys
-
-        # Take the intersection of variable_keys and template_variable_keys
-        variable_keys = list(set(variable_keys) & set(template_variable_keys))
-
-        template = node_data.answer
-        for var in variable_keys:
-            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
-
-        generate_routes = []
-        for part in template.split('Ω'):
-            if part:
-                if cls._is_variable(part, variable_keys):
-                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
-                    value_selector = value_selector_mapping[var_key]
-                    generate_routes.append(VarGenerateRouteChunk(
-                        value_selector=value_selector
-                    ))
-                else:
-                    generate_routes.append(TextGenerateRouteChunk(
-                        text=part
-                    ))
-
-        return generate_routes
-
-    @classmethod
-    def _is_variable(cls, part, variable_keys):
-        cleaned_part = part.replace('{{', '').replace('}}', '')
-        return part.startswith('{{') and cleaned_part in variable_keys
-
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: AnswerNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
         :return:
         """
@@ -126,6 +73,6 @@ class AnswerNode(BaseNode):

         variable_mapping = {}
         for variable_selector in variable_selectors:
-            variable_mapping[variable_selector.variable] = variable_selector.value_selector
+            variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector

         return variable_mapping
169  api/core/workflow/nodes/answer/answer_stream_generate_router.py  Normal file
@@ -0,0 +1,169 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    AnswerStreamGenerateRoute,
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser


class AnswerStreamGeneratorRouter:

    @classmethod
    def init(cls,
             node_id_config_mapping: dict[str, dict],
             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
             ) -> AnswerStreamGenerateRoute:
        """
        Get stream generate routes.
        :return:
        """
        # parse stream output node value selectors of answer nodes
        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
        for answer_node_id, node_config in node_id_config_mapping.items():
            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
                continue

            # get generate route for stream output
            generate_route = cls._extract_generate_route_selectors(node_config)
            answer_generate_route[answer_node_id] = generate_route

        # fetch answer dependencies
        answer_node_ids = list(answer_generate_route.keys())
        answer_dependencies = cls._fetch_answers_dependencies(
            answer_node_ids=answer_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping
        )

        return AnswerStreamGenerateRoute(
            answer_generate_route=answer_generate_route,
            answer_dependencies=answer_dependencies
        )

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data
        :param node_data: node data object
        :return:
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        variable_keys = list(value_selector_mapping.keys())

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))

        template = node_data.answer
        for var in variable_keys:
            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

        generate_routes: list[GenerateRouteChunk] = []
        for part in template.split('Ω'):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(
                        value_selector=value_selector
                    ))
                else:
                    generate_routes.append(TextGenerateRouteChunk(
                        text=part
                    ))

        return generate_routes

    @classmethod
    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors
        :param config: node config
        :return:
        """
        node_data = AnswerNodeData(**config.get("data", {}))
        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def _is_variable(cls, part, variable_keys):
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

    @classmethod
    def _fetch_answers_dependencies(cls,
                                    answer_node_ids: list[str],
                                    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                    node_id_config_mapping: dict[str, dict]
                                    ) -> dict[str, list[str]]:
        """
        Fetch answer dependencies
        :param answer_node_ids: answer node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return:
        """
        answer_dependencies: dict[str, list[str]] = {}
        for answer_node_id in answer_node_ids:
            if answer_dependencies.get(answer_node_id) is None:
                answer_dependencies[answer_node_id] = []

            cls._recursive_fetch_answer_dependencies(
                current_node_id=answer_node_id,
                answer_node_id=answer_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                answer_dependencies=answer_dependencies
            )

        return answer_dependencies

    @classmethod
    def _recursive_fetch_answer_dependencies(cls,
                                             current_node_id: str,
                                             answer_node_id: str,
                                             node_id_config_mapping: dict[str, dict],
                                             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                             answer_dependencies: dict[str, list[str]]
                                             ) -> None:
        """
        Recursively fetch answer dependencies
        :param current_node_id: current node id
        :param answer_node_id: answer node id
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param answer_dependencies: answer dependencies
        :return:
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
            if source_node_type in (
                    NodeType.ANSWER.value,
                    NodeType.IF_ELSE.value,
                    NodeType.QUESTION_CLASSIFIER.value,
            ):
                answer_dependencies[answer_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_answer_dependencies(
                    current_node_id=source_node_id,
                    answer_node_id=answer_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    answer_dependencies=answer_dependencies
                )
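The Ω-splitting in extract_generate_route_from_node_data above is easiest to see on a toy template. A sketch of the mechanism only: the variable key '#llm.text#' and its selector ['llm', 'text'] are assumed example values; real keys come from VariableTemplateParser.

    # Illustrative walk-through of the Ω trick, not code from the diff.
    template = 'The answer is {{#llm.text#}}.'
    var_key = '#llm.text#'

    # wrap each known variable occurrence with Ω markers, then split on Ω
    marked = template.replace('{{' + var_key + '}}', 'Ω{{' + var_key + '}}Ω')
    parts = [p for p in marked.split('Ω') if p]
    # parts == ['The answer is ', '{{#llm.text#}}', '.']
    # -> TextGenerateRouteChunk, VarGenerateRouteChunk, TextGenerateRouteChunk

Using an uncommon sentinel character keeps the split cheap while leaving ordinary template text untouched.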
221  api/core/workflow/nodes/answer/answer_stream_processor.py  Normal file
@@ -0,0 +1,221 @@
import logging
from collections.abc import Generator
from typing import Optional, cast

from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk

logger = logging.getLogger(__name__)


class AnswerStreamProcessor(StreamProcessor):

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        super().__init__(graph, variable_pool)
        self.generate_routes = graph.answer_stream_generate_routes
        self.route_position = {}
        for answer_node_id in self.generate_routes.answer_generate_route:
            self.route_position[answer_node_id] = 0
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        for event in generator:
            if isinstance(event, NodeRunStartedEvent):
                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
                    self.reset()

                yield event
            elif isinstance(event, NodeRunStreamChunkEvent):
                if event.in_iteration_id:
                    yield event
                    continue

                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ] = stream_out_answer_node_ids

                for _ in stream_out_answer_node_ids:
                    yield event
            elif isinstance(event, NodeRunSucceededEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    # update self.route_position after all stream events have finished
                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
                        self.route_position[answer_node_id] += 1

                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

                # remove unreachable nodes
                self._remove_unreachable_nodes(event)

                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(event)
            else:
                yield event

    def reset(self) -> None:
        self.route_position = {}
        for answer_node_id in self.generate_routes.answer_generate_route:
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _generate_stream_outputs_when_node_finished(self,
                                                    event: NodeRunSucceededEvent
                                                    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Generate stream outputs.
        :param event: node run succeeded event
        :return:
        """
        for answer_node_id, position in self.route_position.items():
            # emit only for the answer node itself, or for a pending answer node
            # whose dependencies have all left the rest node ids
            if (event.route_node_state.node_id != answer_node_id
                    and (answer_node_id not in self.rest_node_ids
                         or not all(dep_id not in self.rest_node_ids
                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
                continue

            route_position = self.route_position[answer_node_id]
            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]

            for route_chunk in route_chunks:
                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                    yield NodeRunStreamChunkEvent(
                        id=event.id,
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        chunk_content=route_chunk.text,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                    )
                else:
                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                    value_selector = route_chunk.value_selector
                    if not value_selector:
                        break

                    value = self.variable_pool.get(
                        value_selector
                    )

                    if value is None:
                        break

                    text = value.markdown

                    if text:
                        yield NodeRunStreamChunkEvent(
                            id=event.id,
                            node_id=event.node_id,
                            node_type=event.node_type,
                            node_data=event.node_data,
                            chunk_content=text,
                            from_variable_selector=value_selector,
                            route_node_state=event.route_node_state,
                            parallel_id=event.parallel_id,
                            parallel_start_node_id=event.parallel_start_node_id,
                        )

                self.route_position[answer_node_id] += 1

    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Get the answer node ids that a stream chunk should be emitted for.
        :param event: stream chunk event
        :return:
        """
        if not event.from_variable_selector:
            return []

        stream_output_value_selector = event.from_variable_selector
        if not stream_output_value_selector:
            return []

        stream_out_answer_node_ids = []
        for answer_node_id, route_position in self.route_position.items():
            if answer_node_id not in self.rest_node_ids:
                continue

            # all dependencies of the answer node must have left the rest node ids
            if all(dep_id not in self.rest_node_ids
                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
                    continue

                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]

                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
                    continue

                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector

                # only stream out when the chunk's selector matches the event's selector
                if value_selector != stream_output_value_selector:
                    continue

                stream_out_answer_node_ids.append(answer_node_id)

        return stream_out_answer_node_ids

    @classmethod
    def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
        """
        Fetch files from variable value
        :param value: variable value
        :return:
        """
        if not value:
            return []

        files = []
        if isinstance(value, list):
            for item in value:
                file_var = cls._get_file_var_from_value(item)
                if file_var:
                    files.append(file_var)
        elif isinstance(value, dict):
            file_var = cls._get_file_var_from_value(value)
            if file_var:
                files.append(file_var)

        return files

    @classmethod
    def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
        """
        Get file var from value
        :param value: variable value
        :return:
        """
        if not value:
            return None

        if isinstance(value, dict):
            if '__variant' in value and value['__variant'] == FileVar.__name__:
                return value
        elif isinstance(value, FileVar):
            return value.to_dict()

        return None
71  api/core/workflow/nodes/answer/base_stream_processor.py  Normal file
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph


class StreamProcessor(ABC):

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (edge.run_condition
                        and edge.run_condition.branch_identify
                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        Remove target node ids until a merge point is reached.
        """
        if node_id not in self.rest_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
@@ -1,5 +1,6 @@
 from enum import Enum

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from core.workflow.entities.base_node_data_entities import BaseNodeData

@@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData):
     """
     Answer Node Data.
     """
-    answer: str
+    answer: str = Field(..., description="answer template string")


 class GenerateRouteChunk(BaseModel):
     """
     Generate Route Chunk.
     """
-    type: str
+
+    class ChunkType(Enum):
+        VAR = "var"
+        TEXT = "text"
+
+    type: ChunkType = Field(..., description="generate route chunk type")


 class VarGenerateRouteChunk(GenerateRouteChunk):
     """
     Var Generate Route Chunk.
     """
-    type: str = "var"
-    value_selector: list[str]
+    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
+    """generate route chunk type"""
+    value_selector: list[str] = Field(..., description="value selector")


 class TextGenerateRouteChunk(GenerateRouteChunk):
     """
     Text Generate Route Chunk.
     """
-    type: str = "text"
-    text: str
+    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
+    """generate route chunk type"""
+    text: str = Field(..., description="text")
+
+
+class AnswerNodeDoubleLink(BaseModel):
+    node_id: str = Field(..., description="node id")
+    source_node_ids: list[str] = Field(..., description="source node ids")
+    target_node_ids: list[str] = Field(..., description="target node ids")
+
+
+class AnswerStreamGenerateRoute(BaseModel):
+    """
+    AnswerStreamGenerateRoute entity
+    """
+    answer_dependencies: dict[str, list[str]] = Field(
+        ...,
+        description="answer dependencies (answer node id -> dependent answer node ids)"
+    )
+    answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
+        ...,
+        description="answer generate route (answer node id -> generate route chunks)"
+    )
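To make the new route entities concrete, here is a small, illustrative construction; the node id and selector values are invented for the example and are not defaults from the diff.

    # Hypothetical usage of the pydantic entities above.
    route = AnswerStreamGenerateRoute(
        answer_dependencies={'answer_1': []},
        answer_generate_route={
            'answer_1': [
                TextGenerateRouteChunk(text='The answer is '),
                VarGenerateRouteChunk(value_selector=['llm', 'text']),
            ]
        },
    )
    assert route.answer_generate_route['answer_1'][0].type == GenerateRouteChunk.ChunkType.TEXT

Moving the chunk type from a bare string to the ChunkType enum lets the stream processors compare against GenerateRouteChunk.ChunkType.VAR instead of the error-prone literal "var".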
@@ -1,142 +1,103 @@
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, Sequence
-from enum import Enum
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Optional

 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from models import WorkflowNodeExecutionStatus
-
-
-class UserFrom(Enum):
-    """
-    User from
-    """
-    ACCOUNT = "account"
-    END_USER = "end-user"
-
-    @classmethod
-    def value_of(cls, value: str) -> "UserFrom":
-        """
-        Value of
-        :param value: value
-        :return:
-        """
-        for item in cls:
-            if item.value == value:
-                return item
-        raise ValueError(f"Invalid value: {value}")
+from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent


 class BaseNode(ABC):
     _node_data_cls: type[BaseNodeData]
     _node_type: NodeType

-    tenant_id: str
-    app_id: str
-    workflow_id: str
-    user_id: str
-    user_from: UserFrom
-    invoke_from: InvokeFrom
-
-    workflow_call_depth: int
-
-    node_id: str
-    node_data: BaseNodeData
-    node_run_result: Optional[NodeRunResult] = None
-
-    callbacks: Sequence[WorkflowCallback]
-
-    is_answer_previous_node: bool = False
-
-    def __init__(self, tenant_id: str,
-                 app_id: str,
-                 workflow_id: str,
-                 user_id: str,
-                 user_from: UserFrom,
-                 invoke_from: InvokeFrom,
-                 callbacks: Sequence[WorkflowCallback] | None = None,
-                 workflow_call_depth: int = 0) -> None:
-        self.tenant_id = tenant_id
-        self.app_id = app_id
-        self.workflow_id = workflow_id
-        self.user_id = user_id
-        self.user_from = user_from
-        self.invoke_from = invoke_from
-        self.workflow_call_depth = workflow_call_depth
+    def __init__(self,
+                 id: str,
+                 config: Mapping[str, Any],
+                 graph_init_params: GraphInitParams,
+                 graph: Graph,
+                 graph_runtime_state: GraphRuntimeState,
+                 previous_node_id: Optional[str] = None,
+                 thread_pool_id: Optional[str] = None) -> None:
+        self.id = id
+        self.tenant_id = graph_init_params.tenant_id
+        self.app_id = graph_init_params.app_id
+        self.workflow_type = graph_init_params.workflow_type
+        self.workflow_id = graph_init_params.workflow_id
+        self.graph_config = graph_init_params.graph_config
+        self.user_id = graph_init_params.user_id
+        self.user_from = graph_init_params.user_from
+        self.invoke_from = graph_init_params.invoke_from
+        self.workflow_call_depth = graph_init_params.call_depth
+        self.graph = graph
+        self.graph_runtime_state = graph_runtime_state
+        self.previous_node_id = previous_node_id
+        self.thread_pool_id = thread_pool_id

-        # TODO: May need to check if key exists.
-        self.node_id = config["id"]
-        if not self.node_id:
+        node_id = config.get("id")
+        if not node_id:
             raise ValueError("Node ID is required.")

+        self.node_id = node_id
         self.node_data = self._node_data_cls(**config.get("data", {}))
-        self.callbacks = callbacks or []

     @abstractmethod
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) \
+            -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         raise NotImplementedError

-    def run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node entry
-        :param variable_pool: variable pool
         :return:
         """
-        try:
-            result = self._run(
-                variable_pool=variable_pool
-            )
-            self.node_run_result = result
-            return result
-        except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-            )
+        result = self._run()

-    def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
-        """
-        Publish text chunk
-        :param text: chunk text
-        :param value_selector: value selector
-        :return:
-        """
-        if self.callbacks:
-            for callback in self.callbacks:
-                callback.on_node_text_chunk(
-                    node_id=self.node_id,
-                    text=text,
-                    metadata={
-                        "node_type": self.node_type,
-                        "is_answer_previous_node": self.is_answer_previous_node,
-                        "value_selector": value_selector
-                    }
-                )
+        if isinstance(result, NodeRunResult):
+            yield RunCompletedEvent(
+                run_result=result
+            )
+        else:
+            yield from result

     @classmethod
-    def extract_variable_selector_to_variable_mapping(cls, config: dict):
+    def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
         :param config: node config
         :return:
         """
+        node_id = config.get("id")
+        if not node_id:
+            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
+
         node_data = cls._node_data_cls(**config.get("data", {}))
-        return cls._extract_variable_selector_to_variable_mapping(node_data)
+        return cls._extract_variable_selector_to_variable_mapping(
+            graph_config=graph_config,
+            node_id=node_id,
+            node_data=node_data
+        )

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: BaseNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
         :param graph_config: graph config
         :param node_id: node id
         :param node_data: node data
         :return:
         """
@@ -158,38 +119,3 @@ class BaseNode(ABC):
         :return:
         """
         return self._node_type
-
-
-class BaseIterationNode(BaseNode):
-    @abstractmethod
-    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
-        """
-        Run node
-        :param variable_pool: variable pool
-        :return:
-        """
-        raise NotImplementedError
-
-    def run(self, variable_pool: VariablePool) -> BaseIterationState:
-        """
-        Run node entry
-        :param variable_pool: variable pool
-        :return:
-        """
-        return self._run(variable_pool=variable_pool)
-
-    def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        return self._get_next_iteration(variable_pool, state)
-
-    @abstractmethod
-    def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        raise NotImplementedError
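Under the new contract, _run() takes no variable pool argument and reads state from self.graph_runtime_state, while run() wraps a plain NodeRunResult in a RunCompletedEvent. A hypothetical minimal node to illustrate the shape; EchoNodeData and the ['sys', 'query'] selector are invented for this sketch, and the node type is a placeholder.

    # Sketch of a custom node against the new BaseNode contract.
    class EchoNode(BaseNode):
        _node_data_cls = EchoNodeData  # assumed BaseNodeData subclass with no extra fields
        _node_type = NodeType.CODE     # placeholder node type for the sketch

        def _run(self) -> NodeRunResult:
            # read an input from the shared variable pool instead of a parameter
            value = self.graph_runtime_state.variable_pool.get_any(['sys', 'query'])
            return NodeRunResult(
                status=WorkflowNodeExecutionStatUS if False else WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={'echo': value},
            )

The concrete nodes in the hunks below (code, end, answer) follow exactly this migration: the variable_pool parameter disappears and each read goes through self.graph_runtime_state.variable_pool.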
@@ -1,4 +1,5 @@
-from typing import Optional, Union, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, Union, cast

 from configs import dify_config
 from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
@@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.code.entities import CodeNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -33,13 +33,13 @@ class CodeNode(BaseNode):

         return code_provider.get_default_config()

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run code
-        :param variable_pool: variable pool
         :return:
         """
-        node_data = cast(CodeNodeData, self.node_data)
+        node_data = self.node_data
+        node_data = cast(CodeNodeData, node_data)

         # Get code language
         code_language = node_data.code_language
@@ -49,7 +49,7 @@ class CodeNode(BaseNode):
         variables = {}
         for variable_selector in node_data.variables:
             variable = variable_selector.variable
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)

             variables[variable] = value
         # Run code
@@ -311,13 +311,19 @@ class CodeNode(BaseNode):
         return transformed_result

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: CodeNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """

         return {
-            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+            node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
         }
@@ -1,8 +1,7 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast

-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.end.entities import EndNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -12,10 +11,9 @@ class EndNode(BaseNode):
     _node_data_cls = EndNodeData
     _node_type = NodeType.END

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
@@ -24,7 +22,7 @@ class EndNode(BaseNode):

         outputs = {}
         for variable_selector in output_variables:
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
             outputs[variable_selector.variable] = value

         return NodeRunResult(
@@ -34,52 +32,16 @@ class EndNode(BaseNode):
         )

-    @classmethod
-    def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
-        """
-        Extract generate nodes
-        :param graph: graph
-        :param config: node config
-        :return:
-        """
-        node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(EndNodeData, node_data)
-
-        return cls.extract_generate_nodes_from_node_data(graph, node_data)
-
-    @classmethod
-    def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
-        """
-        Extract generate nodes from node data
-        :param graph: graph
-        :param node_data: node data object
-        :return:
-        """
-        nodes = graph.get('nodes', [])
-        node_mapping = {node.get('id'): node for node in nodes}
-
-        variable_selectors = node_data.outputs
-
-        generate_nodes = []
-        for variable_selector in variable_selectors:
-            if not variable_selector.value_selector:
-                continue
-
-            node_id = variable_selector.value_selector[0]
-            if node_id != 'sys' and node_id in node_mapping:
-                node = node_mapping[node_id]
-                node_type = node.get('data', {}).get('type')
-                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
-                    generate_nodes.append(node_id)
-
-        # remove duplicates
-        generate_nodes = list(set(generate_nodes))
-
-        return generate_nodes
-
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: EndNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
         :return:
         """
api/core/workflow/nodes/end/end_stream_generate_router.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
+
+
+class EndStreamGeneratorRouter:
+
+    @classmethod
+    def init(cls,
+             node_id_config_mapping: dict[str, dict],
+             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+             node_parallel_mapping: dict[str, str]
+             ) -> EndStreamParam:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # parse stream output node value selector of end nodes
+        end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
+        for end_node_id, node_config in node_id_config_mapping.items():
+            if not node_config.get('data', {}).get('type') == NodeType.END.value:
+                continue
+
+            # skip end node in parallel
+            if end_node_id in node_parallel_mapping:
+                continue
+
+            # get generate route for stream output
+            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
+            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
+
+        # fetch end dependencies
+        end_node_ids = list(end_stream_variable_selectors_mapping.keys())
+        end_dependencies = cls._fetch_ends_dependencies(
+            end_node_ids=end_node_ids,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_id_config_mapping=node_id_config_mapping
+        )
+
+        return EndStreamParam(
+            end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
+            end_dependencies=end_dependencies
+        )
+
+    @classmethod
+    def extract_stream_variable_selector_from_node_data(cls,
+                                                        node_id_config_mapping: dict[str, dict],
+                                                        node_data: EndNodeData) -> list[list[str]]:
+        """
+        Extract stream variable selector from node data
+        :param node_id_config_mapping: node id config mapping
+        :param node_data: node data object
+        :return:
+        """
+        variable_selectors = node_data.outputs
+
+        value_selectors = []
+        for variable_selector in variable_selectors:
+            if not variable_selector.value_selector:
+                continue
+
+            node_id = variable_selector.value_selector[0]
+            if node_id != 'sys' and node_id in node_id_config_mapping:
+                node = node_id_config_mapping[node_id]
+                node_type = node.get('data', {}).get('type')
+                if (
+                        variable_selector.value_selector not in value_selectors
+                        and node_type == NodeType.LLM.value
+                        and variable_selector.value_selector[1] == 'text'
+                ):
+                    value_selectors.append(variable_selector.value_selector)
+
+        return value_selectors
+
+    @classmethod
+    def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
+            -> list[list[str]]:
+        """
+        Extract stream variable selector from node config
+        :param node_id_config_mapping: node id config mapping
+        :param config: node config
+        :return:
+        """
+        node_data = EndNodeData(**config.get("data", {}))
+        return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
+
+    @classmethod
+    def _fetch_ends_dependencies(cls,
+                                 end_node_ids: list[str],
+                                 reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                 node_id_config_mapping: dict[str, dict]
+                                 ) -> dict[str, list[str]]:
+        """
+        Fetch end dependencies
+        :param end_node_ids: end node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param node_id_config_mapping: node id config mapping
+        :return:
+        """
+        end_dependencies: dict[str, list[str]] = {}
+        for end_node_id in end_node_ids:
+            if end_dependencies.get(end_node_id) is None:
+                end_dependencies[end_node_id] = []
+
+            cls._recursive_fetch_end_dependencies(
+                current_node_id=end_node_id,
+                end_node_id=end_node_id,
+                node_id_config_mapping=node_id_config_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                end_dependencies=end_dependencies
+            )
+
+        return end_dependencies
+
+    @classmethod
+    def _recursive_fetch_end_dependencies(cls,
+                                          current_node_id: str,
+                                          end_node_id: str,
+                                          node_id_config_mapping: dict[str, dict],
+                                          reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                          end_dependencies: dict[str, list[str]]
+                                          ) -> None:
+        """
+        Recursive fetch end dependencies
+        :param current_node_id: current node id
+        :param end_node_id: end node id
+        :param node_id_config_mapping: node id config mapping
+        :param reverse_edge_mapping: reverse edge mapping
+        :param end_dependencies: end dependencies
+        :return:
+        """
+        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
+        for edge in reverse_edges:
+            source_node_id = edge.source_node_id
+            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
+            if source_node_type in (
+                    NodeType.IF_ELSE.value,
+                    NodeType.QUESTION_CLASSIFIER,
+            ):
+                end_dependencies[end_node_id].append(source_node_id)
+            else:
+                cls._recursive_fetch_end_dependencies(
+                    current_node_id=source_node_id,
+                    end_node_id=end_node_id,
+                    node_id_config_mapping=node_id_config_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    end_dependencies=end_dependencies
+                )
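The dependency walk in _recursive_fetch_end_dependencies stops at branching nodes (if-else, question classifier) because only a branch decision determines whether a given end node is still reachable; everything upstream of the branch is irrelevant to that question. A self-contained sketch of the traversal, with plain tuples standing in for the engine's GraphEdge objects (node ids and types here are invented for illustration):

from collections import namedtuple

Edge = namedtuple('Edge', 'source_node_id')

# start -> if_1 -> llm_1 -> end_1, expressed as reverse (incoming) edges
node_types = {'start': 'start', 'if_1': 'if-else', 'llm_1': 'llm'}
reverse_edges = {'end_1': [Edge('llm_1')], 'llm_1': [Edge('if_1')], 'if_1': [Edge('start')]}

def fetch_deps(current_node_id, end_node_id, deps):
    for edge in reverse_edges.get(current_node_id, []):
        src = edge.source_node_id
        if node_types[src] in ('if-else', 'question-classifier'):
            deps[end_node_id].append(src)       # a branch decides reachability: stop here
        else:
            fetch_deps(src, end_node_id, deps)  # keep walking upstream

deps = {'end_1': []}
fetch_deps('end_1', 'end_1', deps)
print(deps)  # {'end_1': ['if_1']}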
api/core/workflow/nodes/end/end_stream_processor.py (new file, 191 lines)
@@ -0,0 +1,191 @@
+import logging
+from collections.abc import Generator
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
+
+logger = logging.getLogger(__name__)
+
+
+class EndStreamProcessor(StreamProcessor):
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        super().__init__(graph, variable_pool)
+        self.end_stream_param = graph.end_stream_param
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+        self.has_outputed = False
+        self.outputed_node_ids = set()
+
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStartedEvent):
+                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
+                    self.reset()
+
+                yield event
+            elif isinstance(event, NodeRunStreamChunkEvent):
+                if event.in_iteration_id:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
+                    yield event
+                    continue
+
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ] = stream_out_end_node_ids
+
+                if stream_out_end_node_ids:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream event finished
+                    for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[end_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
+
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for end_node_id, position in self.route_position.items():
+            # all depends on end node id not in rest node ids
+            if (event.route_node_state.node_id != end_node_id
+                    and (end_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
+                continue
+
+            route_position = self.route_position[end_node_id]
+
+            position = 0
+            value_selectors = []
+            for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                if position >= route_position:
+                    value_selectors.append(current_value_selectors)
+
+                position += 1
+
+            for value_selector in value_selectors:
+                if not value_selector:
+                    continue
+
+                value = self.variable_pool.get(
+                    value_selector
+                )
+
+                if value is None:
+                    break
+
+                text = value.markdown
+
+                if text:
+                    current_node_id = value_selector[0]
+                    if self.has_outputed and current_node_id not in self.outputed_node_ids:
+                        text = '\n' + text
+
+                    self.outputed_node_ids.add(current_node_id)
+                    self.has_outputed = True
+                    yield NodeRunStreamChunkEvent(
+                        id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        chunk_content=text,
+                        from_variable_selector=value_selector,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                    )
+
+                self.route_position[end_node_id] += 1
+
+    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.from_variable_selector:
+            return []
+
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_end_node_ids = []
+        for end_node_id, route_position in self.route_position.items():
+            if end_node_id not in self.rest_node_ids:
+                continue
+
+            # all depends on end node id not in rest node ids
+            if all(dep_id not in self.rest_node_ids
+                   for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
+                if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
+                    continue
+
+                position = 0
+                value_selector = None
+                for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                    if position == route_position:
+                        value_selector = current_value_selectors
+                        break
+
+                    position += 1
+
+                if not value_selector:
+                    continue
+
+                # check chunk node id is before current node id or equal to current node id
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_end_node_ids.append(end_node_id)
+
+        return stream_out_end_node_ids
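Three pieces of bookkeeping recur through the processor above: route_position (how many of an end node's stream selectors have already been emitted) and the has_outputed/outputed_node_ids pair, which prepends a newline whenever the stream switches to a node that has not produced output yet, so chunks from parallel branches do not run together. The newline rule in isolation (a sketch of the idea, not the class's code):

def stitch(chunks):
    # chunks: iterable of (source_node_id, text)
    has_output = False
    seen_node_ids = set()
    for node_id, text in chunks:
        if has_output and node_id not in seen_node_ids:
            text = '\n' + text          # new source mid-stream: separate visually
        seen_node_ids.add(node_id)
        has_output = True
        yield text

print(list(stitch([('llm_1', 'Hello'), ('llm_1', '!'), ('llm_2', 'World')])))
# ['Hello', '!', '\nWorld']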
@@ -1,3 +1,5 @@
+from pydantic import BaseModel, Field
+
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector

@@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
     END Node Data.
     """
     outputs: list[VariableSelector]
+
+
+class EndStreamParam(BaseModel):
+    """
+    EndStreamParam entity
+    """
+    end_dependencies: dict[str, list[str]] = Field(
+        ...,
+        description="end dependencies (end node id -> dependent node ids)"
+    )
+    end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
+        ...,
+        description="end stream variable selector mapping (end node id -> stream variable selectors)"
+    )
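For a graph where one end node sits behind an if-else branch and streams an LLM's text output, the router would produce a parameter object shaped roughly like this (the class is re-declared locally so the snippet runs on its own; the ids and selectors are invented):

from pydantic import BaseModel, Field

class EndStreamParam(BaseModel):  # local mirror of the entity above, for illustration
    end_dependencies: dict[str, list[str]] = Field(...)
    end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(...)

param = EndStreamParam(
    end_dependencies={'end_1': ['if_1']},
    end_stream_variable_selector_mapping={'end_1': [['llm_1', 'text']]},
)
print(param.end_dependencies)  # {'end_1': ['if_1']}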
api/core/workflow/nodes/event.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeRunResult
+
+
+class RunCompletedEvent(BaseModel):
+    run_result: NodeRunResult = Field(..., description="run result")
+
+
+class RunStreamChunkEvent(BaseModel):
+    chunk_content: str = Field(..., description="chunk content")
+    from_variable_selector: list[str] = Field(..., description="from variable selector")
+
+
+class RunRetrieverResourceEvent(BaseModel):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+    context: str = Field(..., description="context")
+
+
+RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
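These event types let a node's _run yield stream chunks, retrieved context, and the final result from a single generator; callers dispatch on the event's type. The pattern in miniature, with dataclass stand-ins instead of the real pydantic entities:

from dataclasses import dataclass

@dataclass
class RunStreamChunk:          # stand-in for RunStreamChunkEvent
    chunk_content: str

@dataclass
class RunCompleted:            # stand-in for RunCompletedEvent
    outputs: dict

def run():
    yield RunStreamChunk('Hel')
    yield RunStreamChunk('lo')
    yield RunCompleted({'text': 'Hello'})

for event in run():
    if isinstance(event, RunStreamChunk):
        print('chunk:', event.chunk_content)
    elif isinstance(event, RunCompleted):
        print('done:', event.outputs)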
@@ -1,15 +1,14 @@
 import logging
+from collections.abc import Mapping, Sequence
 from mimetypes import guess_extension
 from os import path
-from typing import cast
+from typing import Any, cast

 from configs import dify_config
 from core.app.segments import parser
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.tool_file_manager import ToolFileManager
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.http_request.entities import (
     HttpRequestNodeData,

@@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode):
             },
         }

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
         # TODO: Switch to use segment directly
         if node_data.authorization.config and node_data.authorization.config.api_key:
-            node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
+            node_data.authorization.config.api_key = parser.convert_template(
+                template=node_data.authorization.config.api_key,
+                variable_pool=self.graph_runtime_state.variable_pool
+            ).text

         # init http executor
         http_executor = None
         try:
             http_executor = HttpExecutor(
-                node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
+                node_data=node_data,
+                timeout=self._get_request_timeout(node_data),
+                variable_pool=self.graph_runtime_state.variable_pool
             )

             # invoke http executor

@@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode):
         return timeout

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: HttpRequestNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
         node_data = cast(HttpRequestNodeData, node_data)
         try:
             http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)

@@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode):

             variable_mapping = {}
             for variable_selector in variable_selectors:
-                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+                variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector

             return variable_mapping
         except Exception as e:
@@ -3,20 +3,7 @@ from typing import Literal, Optional
 from pydantic import BaseModel

 from core.workflow.entities.base_node_data_entities import BaseNodeData
-
-
-class Condition(BaseModel):
-    """
-    Condition entity
-    """
-    variable_selector: list[str]
-    comparison_operator: Literal[
-        # for string or array
-        "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match",
-        # for number
-        "=", "≠", ">", "<", "≥", "≤", "null", "not null"
-    ]
-    value: Optional[str] = None
+from core.workflow.utils.condition.entities import Condition


 class IfElseNodeData(BaseNodeData):
@@ -1,13 +1,10 @@
-import re
-from collections.abc import Sequence
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast

-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
-from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from core.workflow.nodes.if_else.entities import IfElseNodeData
+from core.workflow.utils.condition.processor import ConditionProcessor
 from models.workflow import WorkflowNodeExecutionStatus


@@ -15,31 +12,35 @@ class IfElseNode(BaseNode):
     _node_data_cls = IfElseNodeData
     _node_type = NodeType.IF_ELSE

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
         node_data = cast(IfElseNodeData, node_data)

-        node_inputs = {
+        node_inputs: dict[str, list] = {
             "conditions": []
         }

-        process_datas = {
+        process_datas: dict[str, list] = {
             "condition_results": []
         }

         input_conditions = []
         final_result = False
         selected_case_id = None
+        condition_processor = ConditionProcessor()
         try:
             # Check if the new cases structure is used
             if node_data.cases:
                 for case in node_data.cases:
-                    input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
+                    input_conditions, group_result = condition_processor.process_conditions(
+                        variable_pool=self.graph_runtime_state.variable_pool,
+                        conditions=case.conditions
+                    )

                     # Apply the logical operator for the current case
                     final_result = all(group_result) if case.logical_operator == "and" else any(group_result)

@@ -58,7 +59,10 @@ class IfElseNode(BaseNode):

             else:
                 # Fallback to old structure if cases are not defined
-                input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
+                input_conditions, group_result = condition_processor.process_conditions(
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                    conditions=node_data.conditions
+                )

                 final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)

@@ -94,376 +98,17 @@ class IfElseNode(BaseNode):

         return data

-    def evaluate_condition(
-        self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
-    ) -> bool:
-        """
-        Evaluate condition
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :param comparison_operator: comparison operator
-
-        :return: bool
-        """
-        if comparison_operator == "contains":
-            return self._assert_contains(actual_value, expected_value)
-        elif comparison_operator == "not contains":
-            return self._assert_not_contains(actual_value, expected_value)
-        elif comparison_operator == "start with":
-            return self._assert_start_with(actual_value, expected_value)
-        elif comparison_operator == "end with":
-            return self._assert_end_with(actual_value, expected_value)
-        elif comparison_operator == "is":
-            return self._assert_is(actual_value, expected_value)
-        elif comparison_operator == "is not":
-            return self._assert_is_not(actual_value, expected_value)
-        elif comparison_operator == "empty":
-            return self._assert_empty(actual_value)
-        elif comparison_operator == "not empty":
-            return self._assert_not_empty(actual_value)
-        elif comparison_operator == "=":
-            return self._assert_equal(actual_value, expected_value)
-        elif comparison_operator == "≠":
-            return self._assert_not_equal(actual_value, expected_value)
-        elif comparison_operator == ">":
-            return self._assert_greater_than(actual_value, expected_value)
-        elif comparison_operator == "<":
-            return self._assert_less_than(actual_value, expected_value)
-        elif comparison_operator == "≥":
-            return self._assert_greater_than_or_equal(actual_value, expected_value)
-        elif comparison_operator == "≤":
-            return self._assert_less_than_or_equal(actual_value, expected_value)
-        elif comparison_operator == "null":
-            return self._assert_null(actual_value)
-        elif comparison_operator == "not null":
-            return self._assert_not_null(actual_value)
-        elif comparison_operator == "regex match":
-            return self._assert_regex_match(actual_value, expected_value)
-        else:
-            raise ValueError(f"Invalid comparison operator: {comparison_operator}")
-
-    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
-        input_conditions = []
-        group_result = []
-
-        for condition in conditions:
-            actual_variable = variable_pool.get_any(condition.variable_selector)
-
-            if condition.value is not None:
-                variable_template_parser = VariableTemplateParser(template=condition.value)
-                expected_value = variable_template_parser.extract_variable_selectors()
-                variable_selectors = variable_template_parser.extract_variable_selectors()
-                if variable_selectors:
-                    for variable_selector in variable_selectors:
-                        value = variable_pool.get_any(variable_selector.value_selector)
-                        expected_value = variable_template_parser.format({variable_selector.variable: value})
-                else:
-                    expected_value = condition.value
-            else:
-                expected_value = None
-
-            comparison_operator = condition.comparison_operator
-            input_conditions.append(
-                {
-                    "actual_value": actual_variable,
-                    "expected_value": expected_value,
-                    "comparison_operator": comparison_operator
-                }
-            )
-
-            result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
-            group_result.append(result)
-
-        return input_conditions, group_result
-
-    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
-        """
-        Assert contains
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str | list):
-            raise ValueError('Invalid actual value type: string or array')
-
-        if expected_value not in actual_value:
-            return False
-        return True
-
-    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
-        """
-        Assert not contains
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return True
-
-        if not isinstance(actual_value, str | list):
-            raise ValueError('Invalid actual value type: string or array')
-
-        if expected_value in actual_value:
-            return False
-        return True
-
-    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert start with
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if not actual_value.startswith(expected_value):
-            return False
-        return True
-
-    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert end with
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if not actual_value.endswith(expected_value):
-            return False
-        return True
-
-    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert is
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if actual_value != expected_value:
-            return False
-        return True
-
-    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert is not
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if actual_value == expected_value:
-            return False
-        return True
-
-    def _assert_empty(self, actual_value: Optional[str]) -> bool:
-        """
-        Assert empty
-        :param actual_value: actual value
-        :return:
-        """
-        if not actual_value:
-            return True
-        return False
-
-    def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert empty
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is None:
-            return False
-        return re.search(expected_value, actual_value) is not None
-
-    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
-        """
-        Assert not empty
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value:
-            return True
-        return False
-
-    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value != expected_value:
-            return False
-        return True
-
-    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert not equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value == expected_value:
-            return False
-        return True
-
-    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert greater than
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value <= expected_value:
-            return False
-        return True
-
-    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert less than
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value >= expected_value:
-            return False
-        return True
-
-    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert greater than or equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value < expected_value:
-            return False
-        return True
-
-    def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert less than or equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value > expected_value:
-            return False
-        return True
-
-    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
-        """
-        Assert null
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is None:
-            return True
-        return False
-
-    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
-        """
-        Assert not null
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is not None:
-            return True
-        return False
-
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: IfElseNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
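The _assert_* helpers deleted above move into the shared ConditionProcessor, so the iteration node's run conditions and the if-else node now evaluate conditions through one code path. Their common numeric pattern — coerce the expected string to the actual value's type, then compare — reduces to a few lines (a sketch of the idea, not the processor's exact implementation):

import operator
from typing import Optional, Union

_NUMERIC_OPS = {'=': operator.eq, '≠': operator.ne, '>': operator.gt,
                '<': operator.lt, '≥': operator.ge, '≤': operator.le}

def compare_number(actual: Optional[Union[int, float]], expected: str, op: str) -> bool:
    if actual is None:
        return False
    # match the expected value's type to the actual value's type
    coerced = int(expected) if isinstance(actual, int) else float(expected)
    return _NUMERIC_OPS[op](actual, coerced)

assert compare_number(3, '3', '=')
assert compare_number(2.5, '2', '>')
assert not compare_number(None, '1', '≤')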
@@ -1,6 +1,6 @@
 from typing import Any, Optional

-from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData


 class IterationNodeData(BaseIterationNodeData):

@@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
     iterator_selector: list[str]  # variable selector
     output_selector: list[str]  # output selector


+class IterationStartNodeData(BaseNodeData):
+    """
+    Iteration Start Node Data.
+    """
+    pass
+
+
 class IterationState(BaseIterationState):
     """
     Iteration State.
@@ -1,124 +1,371 @@
-from typing import cast
+import logging
+from collections.abc import Generator, Mapping, Sequence
+from datetime import datetime, timezone
+from typing import Any, cast

+from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.base_node_data_entities import BaseIterationState
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
-from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
+from core.workflow.graph_engine.entities.event import (
+    BaseGraphEvent,
+    BaseNodeEvent,
+    BaseParallelBranchEvent,
+    GraphRunFailedEvent,
+    InNodeEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
+from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.utils.condition.entities import Condition
 from models.workflow import WorkflowNodeExecutionStatus

+logger = logging.getLogger(__name__)
+

-class IterationNode(BaseIterationNode):
+class IterationNode(BaseNode):
     """
     Iteration Node.
     """
     _node_data_cls = IterationNodeData
     _node_type = NodeType.ITERATION

-    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run the node.
         """
         self.node_data = cast(IterationNodeData, self.node_data)
-        iterator = variable_pool.get_any(self.node_data.iterator_selector)
+        iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)

-        if not isinstance(iterator, list):
-            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
-
-        state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
-            'iterator_selector': iterator
-        }, outputs=[], metadata=IterationState.MetaData(
-            iterator_length=len(iterator) if iterator is not None else 0
-        ))
-
-        self._set_current_iteration_variable(variable_pool, state)
-        return state
+        if not iterator_list_segment:
+            raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")

-    def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        # resolve current output
-        self._resolve_current_output(variable_pool, state)
-        # move to next iteration
-        self._next_iteration(variable_pool, state)
+        iterator_list_value = iterator_list_segment.to_object()

-        node_data = cast(IterationNodeData, self.node_data)
-        if self._reached_iteration_limit(variable_pool, state):
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+        if not isinstance(iterator_list_value, list):
+            raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
+
+        inputs = {
+            "iterator_selector": iterator_list_value
+        }
+
+        graph_config = self.graph_config
+
+        if not self.node_data.start_node_id:
+            raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
+
+        root_node_id = self.node_data.start_node_id
+
+        # init graph
+        iteration_graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=root_node_id
+        )
+
+        if not iteration_graph:
+            raise ValueError('iteration graph not found')
+
+        leaf_node_ids = iteration_graph.get_leaf_node_ids()
+        iteration_leaf_node_ids = []
+        for leaf_node_id in leaf_node_ids:
+            node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
+            if not node_config:
+                continue
+
+            leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
+            if not leaf_node_iteration_id:
+                continue
+
+            if leaf_node_iteration_id != self.node_id:
+                continue
+
+            iteration_leaf_node_ids.append(leaf_node_id)
+
+            # add condition of end nodes to root node
+            iteration_graph.add_extra_edge(
+                source_node_id=leaf_node_id,
+                target_node_id=root_node_id,
+                run_condition=RunCondition(
+                    type="condition",
+                    conditions=[
+                        Condition(
+                            variable_selector=[self.node_id, "index"],
+                            comparison_operator="<",
+                            value=str(len(iterator_list_value))
+                        )
+                    ]
+                )
+            )
+
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        # append iteration variable (item, index) to variable pool
+        variable_pool.add(
+            [self.node_id, 'index'],
+            0
+        )
+        variable_pool.add(
+            [self.node_id, 'item'],
+            iterator_list_value[0]
+        )
+
+        # init graph engine
+        from core.workflow.graph_engine.graph_engine import GraphEngine
+        graph_engine = GraphEngine(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            workflow_type=self.workflow_type,
+            workflow_id=self.workflow_id,
+            user_id=self.user_id,
+            user_from=self.user_from,
+            invoke_from=self.invoke_from,
+            call_depth=self.workflow_call_depth,
+            graph=iteration_graph,
+            graph_config=graph_config,
+            variable_pool=variable_pool,
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
+        )
+
+        start_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+        yield IterationRunStartedEvent(
+            iteration_id=self.id,
+            iteration_node_id=self.node_id,
+            iteration_node_type=self.node_type,
+            iteration_node_data=self.node_data,
+            start_at=start_at,
+            inputs=inputs,
+            metadata={
+                "iterator_length": len(iterator_list_value)
+            },
+            predecessor_node_id=self.previous_node_id
+        )
+
+        yield IterationRunNextEvent(
+            iteration_id=self.id,
+            iteration_node_id=self.node_id,
+            iteration_node_type=self.node_type,
+            iteration_node_data=self.node_data,
+            index=0,
+            pre_iteration_output=None
+        )
+
+        outputs: list[Any] = []
+        try:
+            # run workflow
+            rst = graph_engine.run()
+            for event in rst:
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                    event.in_iteration_id = self.node_id
+
+                if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
+                    continue
+
+                if isinstance(event, NodeRunSucceededEvent):
+                    if event.route_node_state.node_run_result:
+                        metadata = event.route_node_state.node_run_result.metadata
+                        if not metadata:
+                            metadata = {}
+
+                        if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                            metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                            metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
+                            event.route_node_state.node_run_result.metadata = metadata
+
+                    yield event
+
+                    # handle iteration run result
+                    if event.route_node_state.node_id in iteration_leaf_node_ids:
+                        # append to iteration output variable list
+                        current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
+                        outputs.append(current_iteration_output)
+
+                        # remove all nodes outputs from variable pool
+                        for node_id in iteration_graph.node_ids:
+                            variable_pool.remove_node(node_id)
+
+                        # move to next iteration
+                        current_index = variable_pool.get([self.node_id, 'index'])
+                        if current_index is None:
+                            raise ValueError(f'iteration {self.node_id} current index not found')
+
+                        next_index = int(current_index.to_object()) + 1
+                        variable_pool.add(
+                            [self.node_id, 'index'],
+                            next_index
+                        )
+
+                        if next_index < len(iterator_list_value):
+                            variable_pool.add(
+                                [self.node_id, 'item'],
+                                iterator_list_value[next_index]
+                            )
+
+                        yield IterationRunNextEvent(
+                            iteration_id=self.id,
+                            iteration_node_id=self.node_id,
+                            iteration_node_type=self.node_type,
+                            iteration_node_data=self.node_data,
+                            index=next_index,
+                            pre_iteration_output=jsonable_encoder(
+                                current_iteration_output) if current_iteration_output else None
+                        )
+                elif isinstance(event, BaseGraphEvent):
+                    if isinstance(event, GraphRunFailedEvent):
+                        # iteration run failed
+                        yield IterationRunFailedEvent(
+                            iteration_id=self.id,
+                            iteration_node_id=self.node_id,
+                            iteration_node_type=self.node_type,
+                            iteration_node_data=self.node_data,
+                            start_at=start_at,
+                            inputs=inputs,
+                            outputs={
+                                "output": jsonable_encoder(outputs)
+                            },
+                            steps=len(iterator_list_value),
+                            metadata={
+                                "total_tokens": graph_engine.graph_runtime_state.total_tokens
+                            },
+                            error=event.error,
+                        )
+
+                        yield RunCompletedEvent(
+                            run_result=NodeRunResult(
+                                status=WorkflowNodeExecutionStatus.FAILED,
+                                error=event.error,
+                            )
+                        )
+                        break
+                else:
+                    event = cast(InNodeEvent, event)
+                    yield event
+
+            yield IterationRunSucceededEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                start_at=start_at,
+                inputs=inputs,
+                outputs={
-                    'output': jsonable_encoder(state.outputs)
+                    "output": jsonable_encoder(outputs)
+                },
+                steps=len(iterator_list_value),
+                metadata={
+                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
+                }
+            )
+
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    outputs={
+                        'output': jsonable_encoder(outputs)
+                    }
+                )
+            )
+        except Exception as e:
+            # iteration run failed
+            logger.exception("Iteration run failed")
+            yield IterationRunFailedEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                start_at=start_at,
+                inputs=inputs,
+                outputs={
+                    "output": jsonable_encoder(outputs)
+                },
+                steps=len(iterator_list_value),
+                metadata={
+                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
+                },
+                error=str(e),
+            )

-        return node_data.start_node_id
-
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                )
+            )
+        finally:
+            # remove iteration variable (item, index) from variable pool after iteration run completed
+            variable_pool.remove([self.node_id, 'index'])
+            variable_pool.remove([self.node_id, 'item'])

-    def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Set current iteration variable.
-        :variable_pool: variable pool
-        """
-        node_data = cast(IterationNodeData, self.node_data)
-
-        variable_pool.add((self.node_id, 'index'), state.index)
-        # get the iterator value
-        iterator = variable_pool.get_any(node_data.iterator_selector)
-
-        if iterator is None or not isinstance(iterator, list):
-            return
-
-        if state.index < len(iterator):
-            variable_pool.add((self.node_id, 'item'), iterator[state.index])
-
-    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Move to next iteration.
-        :param variable_pool: variable pool
-        """
-        state.index += 1
-        self._set_current_iteration_variable(variable_pool, state)
-
-    def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Check if iteration limit is reached.
-        :return: True if iteration limit is reached, False otherwise
-        """
-        node_data = cast(IterationNodeData, self.node_data)
-        iterator = variable_pool.get_any(node_data.iterator_selector)
-
-        if iterator is None or not isinstance(iterator, list):
-            return True
-
-        return state.index >= len(iterator)
-
-    def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Resolve current output.
-        :param variable_pool: variable pool
-        """
-        output_selector = cast(IterationNodeData, self.node_data).output_selector
-        output = variable_pool.get_any(output_selector)
-        # clear the output for this iteration
-        variable_pool.remove([self.node_id] + output_selector[1:])
-        state.current_output = output
-        if output is not None:
-            # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
-            if isinstance(output, list):
-                state.outputs.extend(output)
-            else:
-                state.outputs.append(output)
-
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: IterationNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-        return {
-            'input_selector': node_data.iterator_selector,
-        }
+        variable_mapping = {
+            f'{node_id}.input_selector': node_data.iterator_selector,
+        }
+
+        # init graph
+        iteration_graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=node_data.start_node_id
+        )
+
+        if not iteration_graph:
+            raise ValueError('iteration graph not found')
+
+        for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
+            if sub_node_config.get('data', {}).get('iteration_id') != node_id:
+                continue
+
+            # variable selector to variable mapping
+            try:
+                # Get node class
+                from core.workflow.nodes.node_mapping import node_classes
+                node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
+                node_cls = node_classes.get(node_type)
+                if not node_cls:
+                    continue
+
+                node_cls = cast(BaseNode, node_cls)
+
+                sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=graph_config,
+                    config=sub_node_config
+                )
+                sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
+            except NotImplementedError:
+                sub_node_variable_mapping = {}
+
+            # remove iteration variables
+            sub_node_variable_mapping = {
+                sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
+                if value[0] != node_id
+            }
+
+            variable_mapping.update(sub_node_variable_mapping)
+
+        # remove variable out from iteration
+        variable_mapping = {
+            key: value for key, value in variable_mapping.items()
+            if value[0] not in iteration_graph.node_ids
+        }
+
+        return variable_mapping
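The rewrite above removes the iteration state machine entirely: the loop now lives in the graph itself, as an extra edge from each leaf node back to the iteration's start node, guarded by the run condition index < len(iterator), with item/index refreshed in the variable pool after every pass. Stripped of the engine, the control flow is just this (a standalone sketch with a trivial body standing in for the nested graph run):

items = ['a', 'b', 'c']
pool = {('iter_1', 'index'): 0, ('iter_1', 'item'): items[0]}
outputs = []

def run_body(item):
    return item.upper()       # stands in for running the nested graph once

# the back-edge condition: re-enter the body while index < len(items)
while pool[('iter_1', 'index')] < len(items):
    outputs.append(run_body(pool[('iter_1', 'item')]))
    next_index = pool[('iter_1', 'index')] + 1
    pool[('iter_1', 'index')] = next_index
    if next_index < len(items):
        pool[('iter_1', 'item')] = items[next_index]

print(outputs)  # ['A', 'B', 'C']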
api/core/workflow/nodes/iteration/iteration_start_node.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class IterationStartNode(BaseNode):
+    """
+    Iteration Start Node.
+    """
+    _node_data_cls = IterationStartNodeData
+    _node_type = NodeType.ITERATION_START
+
+    def _run(self) -> NodeRunResult:
+        """
+        Run the node.
+        """
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED
+        )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: IterationNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
+        return {}
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -1,3 +1,5 @@
 import logging
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
 from sqlalchemy import func
@@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus
 
 logger = logging.getLogger(__name__)
 
 default_retrieval_model = {
     'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
     'reranking_enable': False,
@@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode):
     _node_data_cls = KnowledgeRetrievalNodeData
     node_type = NodeType.KNOWLEDGE_RETRIEVAL
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
+    def _run(self) -> NodeRunResult:
+        node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
 
         # extract variables
-        variable = variable_pool.get_any(node_data.query_variable_selector)
+        variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
         query = variable
         variables = {
             'query': query
@@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
             )
 
         except Exception as e:
            logger.exception("Error when running knowledge retrieval node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
@@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode):
         return context_list
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: KnowledgeRetrievalNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         variable_mapping = {}
-        variable_mapping['query'] = node_data.query_variable_selector
+        variable_mapping[node_id + '.query'] = node_data.query_variable_selector
         return variable_mapping
 
     def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
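The node_id + '.query' key above is the namespacing convention this PR introduces for variable mappings: every key is prefixed with the owning node's id, so mappings produced by different nodes (including nodes running in parallel branches) cannot collide when merged. A minimal self-contained sketch of the idea, with a hypothetical node id:

variable_mapping = {'query': ['sys', 'query'], 'context': ['kb_node', 'result']}
node_id = '1719478483'  # hypothetical node id
variable_mapping = {node_id + '.' + key: value for key, value in variable_mapping.items()}
print(variable_mapping)
# {'1719478483.query': ['sys', 'query'], '1719478483.context': ['kb_node', 'result']}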
api/core/workflow/nodes/llm/llm_node.py
@@ -1,16 +1,17 @@
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 
+from pydantic import BaseModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -43,17 +46,26 @@ if TYPE_CHECKING:
 
 
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
+    finish_reason: Optional[str] = None
+
+
 class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool
 
         node_inputs = None
         process_data = None
@@ -80,10 +92,15 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]
 
             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield event
 
             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore
 
             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -115,19 +132,34 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage, finish_reason = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
             )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            finish_reason = None
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    finish_reason = event.finish_reason
+                    break
         except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                inputs=node_inputs,
-                process_data=process_data
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data
+                )
             )
+            return
 
         outputs = {
             'text': result_text,
@@ -135,22 +167,26 @@ class LLMNode(BaseNode):
             'finish_reason': finish_reason
         }
 
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=node_inputs,
-            process_data=process_data,
-            outputs=outputs,
-            metadata={
-                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
-                NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=node_inputs,
+                process_data=process_data,
+                outputs=outputs,
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                },
+                llm_usage=usage
+            )
        )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model
@@ -170,23 +206,31 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage, finish_reason = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage, finish_reason
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
+        finish_reason = None
@@ -194,7 +238,10 @@ class LLMNode(BaseNode):
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )
 
             if not model:
                 model = result.model
@@ -211,11 +258,15 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage, finish_reason
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage,
+            finish_reason=finish_reason
+        )
 
     def _transform_chat_messages(self,
-                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
         Transform chat messages
@@ -224,13 +275,13 @@ class LLMNode(BaseNode):
         """
 
         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == 'jinja2':
+            if messages.edition_type == 'jinja2' and messages.jinja2_text:
                 messages.text = messages.jinja2_text
 
             return messages
 
         for message in messages:
-            if message.edition_type == 'jinja2':
+            if message.edition_type == 'jinja2' and message.jinja2_text:
                 message.text = message.jinja2_text
 
         return messages
@@ -348,7 +399,7 @@ class LLMNode(BaseNode):
 
         return files
 
-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data
@@ -356,15 +407,18 @@ class LLMNode(BaseNode):
         :return:
         """
         if not node_data.context.enabled:
-            return None
+            return
 
         if not node_data.context.variable_selector:
-            return None
+            return
 
         context_value = variable_pool.get_any(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []
@@ -381,17 +435,10 @@ class LLMNode(BaseNode):
                     if retriever_resource:
                         original_retriever_resource.append(retriever_resource)
 
-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )
 
     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """
@@ -574,7 +621,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -646,13 +694,19 @@ class LLMNode(BaseNode):
         db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: LLMNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
 
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -702,6 +756,10 @@ class LLMNode(BaseNode):
         for variable_selector in node_data.prompt_config.jinja2_variables or []:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
        return variable_mapping
 
     @classmethod
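The LLM node refactor above replaces tuple-returning helpers with generators: _handle_invoke_result now yields a RunStreamChunkEvent per chunk and finishes with a single ModelInvokeCompleted carrying text, usage, and finish reason, while _run drains that stream. A minimal runnable sketch of the same pattern; the dataclasses below stand in for the real event types and are not Dify's:

from collections.abc import Generator
from dataclasses import dataclass, field

@dataclass
class StreamChunk:          # stands in for RunStreamChunkEvent
    chunk_content: str

@dataclass
class InvokeCompleted:      # stands in for ModelInvokeCompleted
    text: str
    usage: dict = field(default_factory=dict)

def handle_invoke_result(chunks) -> Generator[StreamChunk | InvokeCompleted, None, None]:
    full_text = ''
    for text in chunks:
        full_text += text
        yield StreamChunk(chunk_content=text)   # stream upward immediately
    yield InvokeCompleted(text=full_text)       # terminal event with totals

# The caller forwards chunks and captures the terminal result:
result_text = ''
for event in handle_invoke_result(['Hello', ', ', 'world']):
    if isinstance(event, InvokeCompleted):
        result_text = event.text
print(result_text)  # Hello, world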
api/core/workflow/nodes/loop/loop_node.py
@@ -1,20 +1,34 @@
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
+from core.workflow.utils.condition.entities import Condition
 
 
-class LoopNode(BaseIterationNode):
+class LoopNode(BaseNode):
     """
     Loop Node.
     """
     _node_data_cls = LoopNodeData
     _node_type = NodeType.LOOP
 
-    def _run(self, variable_pool: VariablePool) -> LoopState:
-        return super()._run(variable_pool)
+    def _run(self) -> LoopState:
+        return super()._run()
 
-    def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
+    @classmethod
+    def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
         """
-        Get next iteration start node id based on the graph.
+        Get conditions.
         """
+        node_id = node_config.get('id')
+        if not node_id:
+            return []
+
+        # TODO waiting for implementation
+        return [Condition(
+            variable_selector=[node_id, 'index'],
+            comparison_operator="≤",
+            value_type="value_selector",
+            value_selector=[]
+        )]
api/core/workflow/nodes/node_mapping.py (new file, +37)
@@ -0,0 +1,37 @@
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.end.end_node import EndNode
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
+from core.workflow.nodes.variable_assigner import VariableAssignerNode
+
+node_classes = {
+    NodeType.START: StartNode,
+    NodeType.END: EndNode,
+    NodeType.ANSWER: AnswerNode,
+    NodeType.LLM: LLMNode,
+    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
+    NodeType.IF_ELSE: IfElseNode,
+    NodeType.CODE: CodeNode,
+    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
+    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
+    NodeType.HTTP_REQUEST: HttpRequestNode,
+    NodeType.TOOL: ToolNode,
+    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
+    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
+    NodeType.ITERATION: IterationNode,
+    NodeType.ITERATION_START: IterationStartNode,
+    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
+    NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
+}
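node_mapping.py centralizes the NodeType-to-class registry that the new graph engine and the single-step debugger use for dispatch (WorkflowEntry.single_step_run below resolves NodeType from node_config['data']['type'] and looks it up here). A self-contained toy version of the same lookup, with illustrative names:

from enum import Enum

class NodeType(Enum):
    START = 'start'
    LLM = 'llm'

class StartNode: ...
class LLMNode: ...

node_registry = {NodeType.START: StartNode, NodeType.LLM: LLMNode}

def resolve(node_config: dict) -> type:
    node_type = NodeType(node_config.get('data', {}).get('type'))
    node_cls = node_registry.get(node_type)
    if node_cls is None:
        raise ValueError(f'Node class not found for node type {node_type}')
    return node_cls

print(resolve({'data': {'type': 'llm'}}))  # <class '__main__.LLMNode'>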
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -1,6 +1,7 @@
 import json
 import uuid
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode):
         }
     }
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run the node.
         """
         node_data = cast(ParameterExtractorNodeData, self.node_data)
-        variable = variable_pool.get_any(node_data.query)
+        variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
         if not variable:
             raise ValueError("Input variable content not found or is empty")
         query = variable
@@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode):
             raise ValueError("Model schema not found")
 
         # fetch memory
-        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
                 and node_data.reasoning_mode == 'function_call':
             # use function call
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
-                node_data, query, variable_pool, model_config, memory
+                node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
             )
         else:
             # use prompt engineering
-            prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
+            prompt_messages = self._generate_prompt_engineering_prompt(node_data,
+                                                                       query,
+                                                                       self.graph_runtime_state.variable_pool,
+                                                                       model_config,
                                                                        memory)
             prompt_message_tools = []
 
@@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode):
                 NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                 NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                 NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+            },
+            llm_usage=usage
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
@@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode):
         return self._model_instance, self._model_config
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
-        str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: ParameterExtractorNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         node_data = node_data
 
         variable_mapping = {
             'query': node_data.query
         }
@@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode):
         for selector in variable_template_parser.extract_variable_selectors():
             variable_mapping[selector.variable] = selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping
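The branch above picks an extraction strategy from the model's advertised features: native function calling only when the model supports it and the node is configured for it, otherwise prompt engineering. A hedged, self-contained sketch of that dispatch rule (feature names here are illustrative strings, not the real ModelFeature enum):

TOOL_CALL_FEATURES = {'tool-call', 'multi-tool-call'}  # illustrative feature names

def choose_strategy(model_features: set[str], reasoning_mode: str) -> str:
    # Function calling requires both model support and node configuration.
    if model_features & TOOL_CALL_FEATURES and reasoning_mode == 'function_call':
        return 'function_call'
    return 'prompt_engineering'

print(choose_strategy({'tool-call'}, 'function_call'))  # function_call
print(choose_strategy(set(), 'function_call'))          # prompt_engineering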
api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,10 +1,12 @@
 import json
 import logging
-from typing import Optional, Union, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, Union, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
     _node_data_cls = QuestionClassifierNodeData
     node_type = NodeType.QUESTION_CLASSIFIER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
+    def _run(self) -> NodeRunResult:
+        node_data = cast(QuestionClassifierNodeData, self.node_data)
+        variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
         variable = variable_pool.get(node_data.query_variable_selector)
@@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # handle invoke result
-        result_text, usage, finish_reason = self._invoke_llm(
+        generator = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
         )
+
+        result_text = ''
+        usage = LLMUsage.empty_usage()
+        finish_reason = None
+        for event in generator:
+            if isinstance(event, ModelInvokeCompleted):
+                result_text = event.text
+                usage = event.usage
+                finish_reason = event.finish_reason
+                break
 
         category_name = node_data.classes[0].name
         category_id = node_data.classes[0].id
         try:
@@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode):
                     NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                     NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                     NodeRunMetadataKey.CURRENCY: usage.currency
-                }
+                },
+                llm_usage=usage
             )
 
         except ValueError as e:
@@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode):
                     NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                    NodeRunMetadataKey.CURRENCY: usage.currency
-                }
+                },
+                llm_usage=usage
            )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: QuestionClassifierNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         variable_mapping = {'query': node_data.query_variable_selector}
         variable_selectors = []
         if node_data.instruction:
@@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode):
             variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping
 
     @classmethod
api/core/workflow/nodes/start/start_node.py
@@ -1,7 +1,9 @@
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
+from collections.abc import Mapping, Sequence
+from typing import Any
+
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
+from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.start.entities import StartNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -11,14 +13,13 @@ class StartNode(BaseNode):
     _node_data_cls = StartNodeData
     _node_type = NodeType.START
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
-        node_inputs = dict(variable_pool.user_inputs)
-        system_inputs = variable_pool.system_variables
+        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
+        system_inputs = self.graph_runtime_state.variable_pool.system_variables
 
         for var in system_inputs:
             node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
@@ -30,9 +31,16 @@ class StartNode(BaseNode):
         )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: StartNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
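StartNode now reads both pools from self.graph_runtime_state and flattens system variables under the system-variable prefix. A runnable sketch of that merge, assuming SYSTEM_VARIABLE_NODE_ID is 'sys' (the diff's ['sys', 'conversation_id'] selectors elsewhere point that way):

SYSTEM_VARIABLE_NODE_ID = 'sys'  # assumed value
user_inputs = {'topic': 'parallel workflows'}
system_variables = {'query': 'hello', 'conversation_id': 'c-1'}

node_inputs = dict(user_inputs)
for var in system_variables:
    node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_variables[var]
print(node_inputs)
# {'topic': 'parallel workflows', 'sys.query': 'hello', 'sys.conversation_id': 'c-1'}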
api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,15 +1,16 @@
 import os
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
 
 from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
 MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
 
 
 class TemplateTransformNode(BaseNode):
     _node_data_cls = TemplateTransformNodeData
     _node_type = NodeType.TEMPLATE_TRANSFORM
@@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode):
             }
         }
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
         """
@@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode):
         variables = {}
         for variable_selector in node_data.variables:
             variable_name = variable_selector.variable
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
             variables[variable_name] = value
         # Run code
         try:
@@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode):
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e)
             )
 
         if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
             return NodeRunResult(
                 inputs=variables,
@@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode):
                 'output': result['result']
             }
         )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: TemplateTransformNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         return {
-            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+            node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
         }
api/core/workflow/nodes/tool/tool_node.py
@@ -26,7 +26,7 @@ class ToolNode(BaseNode):
     _node_data_cls = ToolNodeData
     _node_type = NodeType.TOOL
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run the tool node
         """
@@ -56,8 +56,8 @@ class ToolNode(BaseNode):
 
         # get parameters
         tool_parameters = tool_runtime.get_runtime_parameters() or []
-        parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
-        parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
+        parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
+        parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
 
         try:
             messages = ToolEngine.workflow_invoke(
@@ -66,6 +66,7 @@ class ToolNode(BaseNode):
                 user_id=self.user_id,
                 workflow_tool_callback=DifyWorkflowCallbackHandler(),
                 workflow_call_depth=self.workflow_call_depth,
+                thread_pool_id=self.thread_pool_id,
             )
         except Exception as e:
             return NodeRunResult(
@@ -145,7 +146,8 @@ class ToolNode(BaseNode):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
+    def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
+            -> tuple[str, list[FileVar], list[dict]]:
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """
@@ -221,9 +223,16 @@ class ToolNode(BaseNode):
         return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: ToolNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
@@ -239,4 +248,8 @@ class ToolNode(BaseNode):
         elif input.type == 'constant':
             pass
 
+        result = {
+            node_id + '.' + key: value for key, value in result.items()
+        }
+
         return result
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
@@ -1,8 +1,7 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
     _node_data_cls = VariableAssignerNodeData
     _node_type = NodeType.VARIABLE_AGGREGATOR
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data = cast(VariableAssignerNodeData, self.node_data)
         # Get variables
         outputs = {}
@@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
 
         if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
             for selector in node_data.variables:
-                variable = variable_pool.get_any(selector)
+                variable = self.graph_runtime_state.variable_pool.get_any(selector)
                 if variable is not None:
                     outputs = {
                         "output": variable
@@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode):
         else:
             for group in node_data.advanced_settings.groups:
                 for selector in group.variables:
-                    variable = variable_pool.get_any(selector)
+                    variable = self.graph_runtime_state.variable_pool.get_any(selector)
 
                     if variable is not None:
                         outputs[group.group_name] = {
@@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode):
         )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: VariableAssignerNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         return {}
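The aggregator's ungrouped path above takes the first selector that resolves to a non-None value. A self-contained sketch of that first-match semantics (the real node reads from the variable pool; stopping at the first hit is assumed here):

pool = {('node_a', 'text'): None, ('node_b', 'text'): 'hello'}
selectors = [('node_a', 'text'), ('node_b', 'text')]

outputs = {}
for selector in selectors:
    variable = pool.get(selector)
    if variable is not None:
        outputs = {'output': variable}
        break
print(outputs)  # {'output': 'hello'}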
api/core/workflow/nodes/variable_assigner (conversation variable assigner)
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
 from core.app.segments import SegmentType, Variable, factory
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from extensions.ext_database import db
 from models import ConversationVariable, WorkflowNodeExecutionStatus
@@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode):
     _node_data_cls: type[BaseNodeData] = VariableAssignerData
     _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         data = cast(VariableAssignerData, self.node_data)
 
         # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
-        original_variable = variable_pool.get(data.assigned_variable_selector)
+        original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
         if not isinstance(original_variable, Variable):
             raise VariableAssignerNodeError('assigned variable not found')
 
         match data.write_mode:
             case WriteMode.OVER_WRITE:
-                income_value = variable_pool.get(data.input_variable_selector)
+                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
                 if not income_value:
                     raise VariableAssignerNodeError('input value not found')
                 updated_variable = original_variable.model_copy(update={'value': income_value.value})
 
             case WriteMode.APPEND:
-                income_value = variable_pool.get(data.input_variable_selector)
+                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
                 if not income_value:
                     raise VariableAssignerNodeError('input value not found')
                 updated_value = original_variable.value + [income_value.value]
@@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode):
             raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
 
         # Over write the variable.
-        variable_pool.add(data.assigned_variable_selector, updated_variable)
+        self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
 
         # TODO: Move database operation to the pipeline.
         # Update conversation variable.
-        conversation_id = variable_pool.get(['sys', 'conversation_id'])
+        conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
         if not conversation_id:
             raise VariableAssignerNodeError('conversation_id not found')
         update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
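The match statement above dispatches on the conversation variable's write mode. A self-contained sketch of the two modes shown (the enum values are illustrative, and the real node also validates the incoming segment and handles further modes):

from enum import Enum

class WriteMode(Enum):
    OVER_WRITE = 'over-write'  # illustrative values
    APPEND = 'append'

def apply(write_mode: WriteMode, original_value: list, income_value):
    match write_mode:
        case WriteMode.OVER_WRITE:
            return income_value                      # replace the stored value
        case WriteMode.APPEND:
            return original_value + [income_value]   # append to the array value
        case _:
            raise ValueError(f'unsupported write mode: {write_mode}')

print(apply(WriteMode.OVER_WRITE, ['a'], 'b'))  # b
print(apply(WriteMode.APPEND, ['a'], 'b'))      # ['a', 'b']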
api/core/workflow/utils/condition/__init__.py (new file, empty)

api/core/workflow/utils/condition/entities.py (new file, +17)
@@ -0,0 +1,17 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class Condition(BaseModel):
+    """
+    Condition entity
+    """
+    variable_selector: list[str]
+    comparison_operator: Literal[
+        # for string or array
+        "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
+        # for number
+        "=", "≠", ">", "<", "≥", "≤", "null", "not null"
+    ]
+    value: Optional[str] = None
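Usage sketch for the new Condition entity; the import path matches the module added above, and the node id is hypothetical:

from core.workflow.utils.condition.entities import Condition

cond = Condition(
    variable_selector=['1719478483', 'index'],  # hypothetical node id
    comparison_operator='≤',
    value='10',
)
print(cond.comparison_operator)  # ≤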
api/core/workflow/utils/condition/processor.py (new file, +383)
@@ -0,0 +1,383 @@
+from collections.abc import Sequence
+from typing import Any, Optional
+
+from core.file.file_obj import FileVar
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.utils.condition.entities import Condition
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+
+
+class ConditionProcessor:
+    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
+        input_conditions = []
+        group_result = []
+
+        index = 0
+        for condition in conditions:
+            index += 1
+            actual_value = variable_pool.get_any(
+                condition.variable_selector
+            )
+
+            expected_value = None
+            if condition.value is not None:
+                variable_template_parser = VariableTemplateParser(template=condition.value)
+                variable_selectors = variable_template_parser.extract_variable_selectors()
+                if variable_selectors:
+                    for variable_selector in variable_selectors:
+                        value = variable_pool.get_any(
+                            variable_selector.value_selector
+                        )
+                        expected_value = variable_template_parser.format({variable_selector.variable: value})
+
+                    if expected_value is None:
+                        expected_value = condition.value
+                else:
+                    expected_value = condition.value
+
+            comparison_operator = condition.comparison_operator
+            input_conditions.append(
+                {
+                    "actual_value": actual_value,
+                    "expected_value": expected_value,
+                    "comparison_operator": comparison_operator
+                }
+            )
+
+            result = self.evaluate_condition(actual_value, comparison_operator, expected_value)
+            group_result.append(result)
+
+        return input_conditions, group_result
+
+    def evaluate_condition(
+        self,
+        actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None],
+        comparison_operator: str,
+        expected_value: Optional[str] = None
+    ) -> bool:
+        """
+        Evaluate condition
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :param comparison_operator: comparison operator
+
+        :return: bool
+        """
+        if comparison_operator == "contains":
+            return self._assert_contains(actual_value, expected_value)
+        elif comparison_operator == "not contains":
+            return self._assert_not_contains(actual_value, expected_value)
+        elif comparison_operator == "start with":
+            return self._assert_start_with(actual_value, expected_value)
+        elif comparison_operator == "end with":
+            return self._assert_end_with(actual_value, expected_value)
+        elif comparison_operator == "is":
+            return self._assert_is(actual_value, expected_value)
+        elif comparison_operator == "is not":
+            return self._assert_is_not(actual_value, expected_value)
+        elif comparison_operator == "empty":
+            return self._assert_empty(actual_value)
+        elif comparison_operator == "not empty":
+            return self._assert_not_empty(actual_value)
+        elif comparison_operator == "=":
+            return self._assert_equal(actual_value, expected_value)
+        elif comparison_operator == "≠":
+            return self._assert_not_equal(actual_value, expected_value)
+        elif comparison_operator == ">":
+            return self._assert_greater_than(actual_value, expected_value)
+        elif comparison_operator == "<":
+            return self._assert_less_than(actual_value, expected_value)
+        elif comparison_operator == "≥":
+            return self._assert_greater_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "≤":
+            return self._assert_less_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "null":
+            return self._assert_null(actual_value)
+        elif comparison_operator == "not null":
+            return self._assert_not_null(actual_value)
+        else:
+            raise ValueError(f"Invalid comparison operator: {comparison_operator}")
+
+    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
+        """
+        Assert contains
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str | list):
+            raise ValueError('Invalid actual value type: string or array')
+
+        if expected_value not in actual_value:
+            return False
+        return True
+
+    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
+        """
+        Assert not contains
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return True
+
+        if not isinstance(actual_value, str | list):
+            raise ValueError('Invalid actual value type: string or array')
+
+        if expected_value in actual_value:
+            return False
+        return True
+
+    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert start with
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if not actual_value.startswith(expected_value):
+            return False
+        return True
+
+    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert end with
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if not actual_value.endswith(expected_value):
+            return False
+        return True
+
+    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert is
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if actual_value != expected_value:
+            return False
+        return True
+
+    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert is not
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if actual_value == expected_value:
+            return False
+        return True
+
+    def _assert_empty(self, actual_value: Optional[str]) -> bool:
+        """
+        Assert empty
+        :param actual_value: actual value
+        :return:
+        """
+        if not actual_value:
+            return True
+        return False
+
+    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
+        """
+        Assert not empty
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value:
+            return True
+        return False
+
+    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value != expected_value:
+            return False
+        return True
+
+    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert not equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value == expected_value:
+            return False
+        return True
+
+    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert greater than
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value <= expected_value:
+            return False
+        return True
+
+    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert less than
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value >= expected_value:
+            return False
+        return True
+
+    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float],
+                                      expected_value: str | int | float) -> bool:
+        """
+        Assert greater than or equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value < expected_value:
+            return False
+        return True
+
+    def _assert_less_than_or_equal(self, actual_value: Optional[int | float],
+                                   expected_value: str | int | float) -> bool:
+        """
+        Assert less than or equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value > expected_value:
+            return False
+        return True
+
+    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
+        """
+        Assert null
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value is None:
+            return True
+        return False
+
+    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
+        """
+        Assert not null
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value is not None:
+            return True
+        return False
+
+
+class ConditionAssertionError(Exception):
+    def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
+        self.message = message
+        self.conditions = conditions
+        self.sub_condition_compare_results = sub_condition_compare_results
+        super().__init__(self.message)
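evaluate_condition is pure, so it can be exercised directly; a short usage sketch against the implementation above:

from core.workflow.utils.condition.processor import ConditionProcessor

processor = ConditionProcessor()
print(processor.evaluate_condition(5, '≤', '10'))             # True
print(processor.evaluate_condition('abc', 'contains', 'b'))   # True
print(processor.evaluate_condition(None, 'null'))             # True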
(File diff suppressed because it is too large.)
api/core/workflow/workflow_entry.py (new file, +314)
@@ -0,0 +1,314 @@
+import logging
+import time
+import uuid
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, Optional, cast
+
+from configs import dify_config
+from core.app.app_config.entities import FileExtraConfig
+from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType, UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunEvent
+from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.node_mapping import node_classes
+from models.workflow import (
+    Workflow,
+    WorkflowType,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowEntry:
+    def __init__(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_id: str,
+        workflow_type: WorkflowType,
+        graph_config: Mapping[str, Any],
+        graph: Graph,
+        user_id: str,
+        user_from: UserFrom,
+        invoke_from: InvokeFrom,
+        call_depth: int,
+        variable_pool: VariablePool,
+        thread_pool_id: Optional[str] = None
+    ) -> None:
+        """
+        Init workflow entry
+        :param tenant_id: tenant id
+        :param app_id: app id
+        :param workflow_id: workflow id
+        :param workflow_type: workflow type
+        :param graph_config: workflow graph config
+        :param graph: workflow graph
+        :param user_id: user id
+        :param user_from: user from
+        :param invoke_from: invoke from
+        :param call_depth: call depth
+        :param variable_pool: variable pool
+        :param thread_pool_id: thread pool id
+        """
+        # check call depth
+        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
+        if call_depth > workflow_call_max_depth:
+            raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
+
+        # init workflow run state
+        self.graph_engine = GraphEngine(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            workflow_type=workflow_type,
+            workflow_id=workflow_id,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+            call_depth=call_depth,
+            graph=graph,
+            graph_config=graph_config,
+            variable_pool=variable_pool,
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+            thread_pool_id=thread_pool_id
+        )
+
+    def run(
+        self,
+        *,
+        callbacks: Sequence[WorkflowCallback],
+    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        :param callbacks: workflow callbacks
+        """
+        graph_engine = self.graph_engine
+
+        try:
+            # run workflow
+            generator = graph_engine.run()
+            for event in generator:
+                if callbacks:
+                    for callback in callbacks:
+                        callback.on_event(
+                            event=event
+                        )
+                yield event
+        except GenerateTaskStoppedException:
+            pass
+        except Exception as e:
+            logger.exception("Unknown Error when workflow entry running")
+            if callbacks:
+                for callback in callbacks:
+                    callback.on_event(
+                        event=GraphRunFailedEvent(
+                            error=str(e)
+                        )
+                    )
+            return
+
+    @classmethod
+    def single_step_run(
+        cls,
+        workflow: Workflow,
+        node_id: str,
+        user_id: str,
+        user_inputs: dict
+    ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
+        """
+        Single step run workflow node
+        :param workflow: Workflow instance
+        :param node_id: node id
+        :param user_id: user id
+        :param user_inputs: user inputs
+        :return:
+        """
+        # fetch node info from workflow graph
+        graph = workflow.graph_dict
+        if not graph:
+            raise ValueError('workflow graph not found')
+
+        nodes = graph.get('nodes')
+        if not nodes:
+            raise ValueError('nodes not found in workflow graph')
+
+        # fetch node config from node id
+        node_config = None
+        for node in nodes:
+            if node.get('id') == node_id:
+                node_config = node
+                break
+
+        if not node_config:
+            raise ValueError('node id not found in workflow graph')
+
+        # Get node class
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        node_cls = cast(type[BaseNode], node_cls)
+
+        if not node_cls:
+            raise ValueError(f'Node class not found for node type {node_type}')
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=workflow.environment_variables,
+        )
+
+        # init graph
+        graph = Graph.init(
+            graph_config=workflow.graph_dict
+        )
+
+        # init workflow run state
+        node_instance: BaseNode = node_cls(
+            id=str(uuid.uuid4()),
+            config=node_config,
+            graph_init_params=GraphInitParams(
+                tenant_id=workflow.tenant_id,
+                app_id=workflow.app_id,
+                workflow_type=WorkflowType.value_of(workflow.type),
+                workflow_id=workflow.id,
+                graph_config=workflow.graph_dict,
+                user_id=user_id,
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+                call_depth=0
+            ),
+            graph=graph,
+            graph_runtime_state=GraphRuntimeState(
+                variable_pool=variable_pool,
+                start_at=time.perf_counter()
+            )
+        )
+
+        try:
+            # variable selector to variable mapping
+            try:
+                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=workflow.graph_dict,
+                    config=node_config
+                )
+            except NotImplementedError:
+                variable_mapping = {}
+
+            cls.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id=workflow.tenant_id,
+                node_type=node_type,
+                node_data=node_instance.node_data
+            )
+
+            # run node
+            generator = node_instance.run()
+
+            return node_instance, generator
+        except Exception as e:
+            raise WorkflowNodeRunFailedError(
+                node_instance=node_instance,
+                error=str(e)
+            )
+
+    @classmethod
+    def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
+        """
+        Handle special values
+        :param value: value
+        :return:
+        """
+        if not value:
+            return None
+
+        new_value = dict(value) if value else {}
+        if isinstance(new_value, dict):
+            for key, val in new_value.items():
+                if isinstance(val, FileVar):
|
||||
new_value[key] = val.to_dict()
|
||||
elif isinstance(val, list):
|
||||
new_val = []
|
||||
for v in val:
|
||||
if isinstance(v, FileVar):
|
||||
new_val.append(v.to_dict())
|
||||
else:
|
||||
new_val.append(v)
|
||||
|
||||
new_value[key] = new_val
|
||||
|
||||
return new_value
|
||||
|
||||
@classmethod
|
||||
def mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData
|
||||
) -> None:
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split('.')
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f'Invalid node variable {node_variable}')
|
||||
|
||||
node_variable_key = '.'.join(node_variable_list[1:])
|
||||
|
||||
if (
|
||||
node_variable_key not in user_inputs
|
||||
and node_variable not in user_inputs
|
||||
) and not variable_pool.get(variable_selector):
|
||||
raise ValueError(f'Variable key {node_variable} not found in user inputs.')
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
variable_key_list = cast(list[str], variable_key_list)
|
||||
|
||||
# get input value
|
||||
input_value = user_inputs.get(node_variable)
|
||||
if not input_value:
|
||||
input_value = user_inputs.get(node_variable_key)
|
||||
|
||||
# FIXME: temp fix for image type
|
||||
if node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(input_value, list):
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in input_value:
|
||||
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
|
||||
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
|
||||
file = FileVar(
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=item.get(
|
||||
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
if new_value:
|
||||
value = new_value
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
|
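
For orientation, a minimal sketch of how the new `WorkflowEntry.single_step_run` can be driven from calling code. This is a hedged usage sketch, not code from this commit: `draft_workflow`, the node id, and the input key are illustrative placeholders.

    # Usage sketch (assumptions: draft_workflow is a Workflow row, 'llm' is a
    # node id in its graph, and the input key follows the '{node_id}.{variable}'
    # form expected by mapping_user_inputs_to_variable_pool).
    node_instance, generator = WorkflowEntry.single_step_run(
        workflow=draft_workflow,
        node_id='llm',
        user_id=account.id,
        user_inputs={'llm.query': 'hello'},
    )
    for event in generator:
        # RunEvent / InNodeEvent instances stream out while the node executes;
        # a RunCompletedEvent carries the final NodeRunResult.
        print(type(event).__name__)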
@ -0,0 +1,35 @@
"""add node_execution_id into node_executions

Revision ID: 675b5321501b
Revises: 030f4915f36a
Create Date: 2024-08-12 10:54:02.259331

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '675b5321501b'
down_revision = '030f4915f36a'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
        batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True))
        batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
        batch_op.drop_index('workflow_node_execution_id_idx')
        batch_op.drop_column('node_execution_id')

    # ### end Alembic commands ###
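
To sanity-check this migration locally against a throwaway database, the upgrade and downgrade can be exercised back to back through Alembic's Python API. A minimal sketch, assuming the standard Alembic configuration that ships in the api/ directory:

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumes api/'s Alembic config and a dev database URL
    command.upgrade(cfg, "675b5321501b")    # adds node_execution_id plus its index
    command.downgrade(cfg, "030f4915f36a")  # drops the index and the column again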
@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model):
                 'triggered_from', 'workflow_run_id'),
        db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id',
                 'triggered_from', 'node_id'),
        db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id',
                 'triggered_from', 'node_execution_id'),
    )

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))

@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model):
    workflow_run_id = db.Column(StringUUID)
    index = db.Column(db.Integer, nullable=False)
    predecessor_node_id = db.Column(db.String(255))
    node_execution_id = db.Column(db.String(255), nullable=True)
    node_id = db.Column(db.String(255), nullable=False)
    node_type = db.Column(db.String(255), nullable=False)
    title = db.Column(db.String(255), nullable=False)
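
The new composite index matches the filter columns a lookup by node execution id would use. A sketch of that kind of query; the filter values here are placeholders, not code from this commit:

    # Illustrative only: tenant_id, app_id, workflow_id, and execution_id are
    # placeholders supplied by the caller.
    execution = (
        db.session.query(WorkflowNodeExecution)
        .filter(
            WorkflowNodeExecution.tenant_id == tenant_id,
            WorkflowNodeExecution.app_id == app_id,
            WorkflowNodeExecution.workflow_id == workflow_id,
            WorkflowNodeExecution.triggered_from == 'workflow-run',
            WorkflowNodeExecution.node_execution_id == execution_id,
        )
        .first()
    )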
@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)

current_dsl_version = "0.1.1"
current_dsl_version = "0.1.2"
dsl_to_dify_version_mapping: dict[str, str] = {
    "0.1.2": "0.8.0",
    "0.1.1": "0.6.0",  # dsl version -> from dify version
}
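
The mapping reads "DSL version -> Dify version it shipped with", as the inline comment notes. A hypothetical helper (not part of this commit) illustrates the intended lookup direction:

    def dify_version_for_dsl(dsl_version: str) -> str:
        # Hypothetical helper: fall back to the DSL version itself when unmapped.
        return dsl_to_dify_version_mapping.get(dsl_version, dsl_version)

    assert dify_version_for_dsl("0.1.2") == "0.8.0"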
@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService

@ -103,9 +104,7 @@ class AppGenerateService:
        return max_active_requests

    @classmethod
    def generate_single_iteration(
        cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
    ):
    def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator().single_iteration_generate(

@ -142,7 +141,7 @@ class AppGenerateService:
        )

    @classmethod
    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any:
    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
        """
        Get workflow
        :param app_model: app model
@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.segments import Variable
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account

@ -172,8 +174,13 @@ class WorkflowService:
        Get default block configs
        """
        # return default block config
        workflow_engine_manager = WorkflowEngineManager()
        return workflow_engine_manager.get_default_configs()
        default_block_configs = []
        for node_type, node_class in node_classes.items():
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)

        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """

@ -182,11 +189,18 @@ class WorkflowService:
        :param filters: filter by node config parameters.
        :return:
        """
        node_type = NodeType.value_of(node_type)
        node_type_enum: NodeType = NodeType.value_of(node_type)

        # return default block config
        workflow_engine_manager = WorkflowEngineManager()
        return workflow_engine_manager.get_default_config(node_type, filters)
        node_class = node_classes.get(node_type_enum)
        if not node_class:
            return None

        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None

        return default_config

    def run_draft_workflow_node(
        self, app_model: App, node_id: str, user_inputs: dict, account: Account
@ -200,82 +214,68 @@ class WorkflowService:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        workflow_engine_manager = WorkflowEngineManager()
        start_at = time.perf_counter()

        try:
            node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
            node_instance, generator = WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            )

            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result

                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break

            if not node_run_result:
                raise ValueError("Node run failed with no run result")

            run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            workflow_node_execution = WorkflowNodeExecution(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                workflow_id=draft_workflow.id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
                index=1,
                node_id=e.node_id,
                node_type=e.node_type.value,
                title=e.node_title,
                status=WorkflowNodeExecutionStatus.FAILED.value,
                error=e.error,
                elapsed_time=time.perf_counter() - start_at,
                created_by_role=CreatedByRole.ACCOUNT.value,
                created_by=account.id,
                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
            )
            db.session.add(workflow_node_execution)
            db.session.commit()
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error

            return workflow_node_execution
        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.tenant_id = app_model.tenant_id
        workflow_node_execution.app_id = app_model.id
        workflow_node_execution.workflow_id = draft_workflow.id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type.value
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
        workflow_node_execution.created_by = account.id
        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

        if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
        if run_succeeded and node_run_result:
            # create workflow node execution
            workflow_node_execution = WorkflowNodeExecution(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                workflow_id=draft_workflow.id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
                index=1,
                node_id=node_id,
                node_type=node_instance.node_type.value,
                title=node_instance.node_data.title,
                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
                outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
                execution_metadata=(
                    json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
                ),
                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
                elapsed_time=time.perf_counter() - start_at,
                created_by_role=CreatedByRole.ACCOUNT.value,
                created_by=account.id,
                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
            workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
            workflow_node_execution.process_data = (
                json.dumps(node_run_result.process_data) if node_run_result.process_data else None
            )
            workflow_node_execution.outputs = (
                json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
            )
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
        else:
            # create workflow node execution
            workflow_node_execution = WorkflowNodeExecution(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                workflow_id=draft_workflow.id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
                index=1,
                node_id=node_id,
                node_type=node_instance.node_type.value,
                title=node_instance.node_data.title,
                status=node_run_result.status.value,
                error=node_run_result.error,
                elapsed_time=time.perf_counter() - start_at,
                created_by_role=CreatedByRole.ACCOUNT.value,
                created_by=account.id,
                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
            )
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error

        db.session.add(workflow_node_execution)
        db.session.commit()

@ -321,25 +321,3 @@ class WorkflowService:
            )
        else:
            raise ValueError(f"Invalid app mode: {app_model.mode}")

    @classmethod
    def get_elapsed_time(cls, workflow_run_id: str) -> float:
        """
        Get elapsed time
        """
        elapsed_time = 0.0

        # fetch workflow node execution by workflow_run_id
        workflow_nodes = (
            db.session.query(WorkflowNodeExecution)
            .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id)
            .order_by(WorkflowNodeExecution.created_at.asc())
            .all()
        )
        if not workflow_nodes:
            return elapsed_time

        for node in workflow_nodes:
            elapsed_time += node.elapsed_time

        return elapsed_time
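
The shape of the new single-step flow above is: drain the generator returned by single_step_run until a RunCompletedEvent arrives, then persist exactly one WorkflowNodeExecution row. Reduced to its skeleton (names follow the code above; this is a condensation, not additional commit code):

    node_instance, generator = WorkflowEntry.single_step_run(
        workflow=draft_workflow, node_id=node_id,
        user_inputs=user_inputs, user_id=account.id,
    )
    node_run_result = None
    for event in generator:
        if isinstance(event, RunCompletedEvent):
            node_run_result = event.run_result  # the final NodeRunResult
            break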
@ -1,17 +1,72 @@
import time
import uuid
from os import getenv
from typing import cast

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))


def init_code_node(code_config: dict):
    graph_config = {
        "edges": [
            {
                "id": "start-source-code-target",
                "source": "start",
                "target": "code",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["code", "123", "args1"], 1)
    variable_pool.add(["code", "123", "args2"], 2)

    node = CodeNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=code_config,
    )

    return node


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
    code = """
@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock):
    """
    # trim first 4 spaces at the beginning of each line
    code = "\n".join([line[4:] for line in code.split("\n")])
    node = CodeNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        config={
            "id": "1",
            "data": {
                "outputs": {
                    "result": {
                        "type": "number",
                    },
                },
                "title": "123",
                "variables": [
                    {
                        "variable": "args1",
                        "value_selector": ["1", "123", "args1"],
                    },
                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
                ],
                "answer": "123",
                "code_language": "python3",
                "code": code,
            },
        },
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(["1", "123", "args1"], 1)
    pool.add(["1", "123", "args2"], 2)
    code_config = {
        "id": "code",
        "data": {
            "outputs": {
                "result": {
                    "type": "number",
                },
            },
            "title": "123",
            "variables": [
                {
                    "variable": "args1",
                    "value_selector": ["1", "123", "args1"],
                },
                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
            ],
            "answer": "123",
            "code_language": "python3",
            "code": code,
        },
    }

    node = init_code_node(code_config)

    # execute node
    result = node.run(pool)
    result = node._run()
    assert isinstance(result, NodeRunResult)
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["result"] == 3
    assert result.error is None
@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock):
    """
    # trim first 4 spaces at the beginning of each line
    code = "\n".join([line[4:] for line in code.split("\n")])
    node = CodeNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        config={
            "id": "1",
            "data": {
                "outputs": {
                    "result": {
                        "type": "string",
                    },
                },
                "title": "123",
                "variables": [
                    {
                        "variable": "args1",
                        "value_selector": ["1", "123", "args1"],
                    },
                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
                ],
                "answer": "123",
                "code_language": "python3",
                "code": code,
            },
        },
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(["1", "123", "args1"], 1)
    pool.add(["1", "123", "args2"], 2)
    code_config = {
        "id": "code",
        "data": {
            "outputs": {
                "result": {
                    "type": "string",
                },
            },
            "title": "123",
            "variables": [
                {
                    "variable": "args1",
                    "value_selector": ["1", "123", "args1"],
                },
                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
            ],
            "answer": "123",
            "code_language": "python3",
            "code": code,
        },
    }

    node = init_code_node(code_config)

    # execute node
    result = node.run(pool)

    result = node._run()
    assert isinstance(result, NodeRunResult)
    assert result.status == WorkflowNodeExecutionStatus.FAILED
    assert result.error == "Output variable `result` must be a string"
@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth():
    """
    # trim first 4 spaces at the beginning of each line
    code = "\n".join([line[4:] for line in code.split("\n")])
    node = CodeNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        config={
            "id": "1",
            "data": {
                "outputs": {
                    "string_validator": {
                        "type": "string",
                    },
                    "number_validator": {
                        "type": "number",
                    },
                    "number_array_validator": {
                        "type": "array[number]",
                    },
                    "string_array_validator": {
                        "type": "array[string]",
                    },
                    "object_validator": {
                        "type": "object",
                        "children": {
                            "result": {
                                "type": "number",
                            },
                            "depth": {
                                "type": "object",
                                "children": {
                                    "depth": {
                                        "type": "object",
                                        "children": {
                                            "depth": {
                                                "type": "number",
                                            }
                                        },
                                    }
                                },

    code_config = {
        "id": "code",
        "data": {
            "outputs": {
                "string_validator": {
                    "type": "string",
                },
                "number_validator": {
                    "type": "number",
                },
                "number_array_validator": {
                    "type": "array[number]",
                },
                "string_array_validator": {
                    "type": "array[string]",
                },
                "object_validator": {
                    "type": "object",
                    "children": {
                        "result": {
                            "type": "number",
                        },
                        "depth": {
                            "type": "object",
                            "children": {
                                "depth": {
                                    "type": "object",
                                    "children": {
                                        "depth": {
                                            "type": "number",
                                        }
                                    },
                                }
                            },
                        },
                    },
                },
            },
            "title": "123",
            "variables": [
                {
                    "variable": "args1",
                    "value_selector": ["1", "123", "args1"],
                },
                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
            ],
            "answer": "123",
            "code_language": "python3",
            "code": code,
        },
                "title": "123",
                "variables": [
                    {
                        "variable": "args1",
                        "value_selector": ["1", "123", "args1"],
                    },
                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
                ],
                "answer": "123",
                "code_language": "python3",
                "code": code,
            },
    )
    }

    node = init_code_node(code_config)

    # construct result
    result = {

@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth():
        "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
    }

    node.node_data = cast(CodeNodeData, node.node_data)

    # validate
    node._transform_result(result, node.node_data.outputs)
@ -250,35 +284,30 @@ def test_execute_code_output_object_list():
    """
    # trim first 4 spaces at the beginning of each line
    code = "\n".join([line[4:] for line in code.split("\n")])
    node = CodeNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
        config={
            "id": "1",
            "data": {
                "outputs": {
                    "object_list": {
                        "type": "array[object]",
                    },

    code_config = {
        "id": "code",
        "data": {
            "outputs": {
                "object_list": {
                    "type": "array[object]",
                },
            "title": "123",
            "variables": [
                {
                    "variable": "args1",
                    "value_selector": ["1", "123", "args1"],
                },
                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
            ],
            "answer": "123",
            "code_language": "python3",
            "code": code,
        },
                "title": "123",
                "variables": [
                    {
                        "variable": "args1",
                        "value_selector": ["1", "123", "args1"],
                    },
                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
                ],
                "answer": "123",
                "code_language": "python3",
                "code": code,
            },
    )
    }

    node = init_code_node(code_config)

    # construct result
    result = {

@ -295,6 +324,8 @@ def test_execute_code_output_object_list():
        ]
    }

    node.node_data = cast(CodeNodeData, node.node_data)

    # validate
    node._transform_result(result, node.node_data.outputs)
@ -1,31 +1,69 @@
import time
import uuid
from urllib.parse import urlencode

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock

BASIC_NODE_DATA = {
    "tenant_id": "1",
    "app_id": "1",
    "workflow_id": "1",
    "user_id": "1",
    "user_from": UserFrom.ACCOUNT,
    "invoke_from": InvokeFrom.WEB_APP,
}

# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["a", "b123", "args1"], 1)
pool.add(["a", "b123", "args2"], 2)
def init_http_node(config: dict):
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["a", "b123", "args1"], 1)
    variable_pool.add(["a", "b123", "args2"], 2)

    return HttpRequestNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )


@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {
@ -45,12 +83,11 @@ def test_get(setup_http_mock):
            "params": "A:b",
            "body": None,
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "?A=b" in data

@ -59,7 +96,7 @@ def test_get(setup_http_mock):

@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_no_auth(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock):
            "params": "A:b",
            "body": None,
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "?A=b" in data

@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock):

@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_header(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock):
            "params": "A:b",
            "body": None,
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "?A=b" in data

@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock):

@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_template(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -143,11 +178,11 @@ def test_template(setup_http_mock):
            "params": "A:b\nTemplate:{{#a.b123.args2#}}",
            "body": None,
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "?A=b" in data

@ -158,7 +193,7 @@ def test_template(setup_http_mock):

@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_json(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -178,11 +213,11 @@ def test_json(setup_http_mock):
            "params": "A:b",
            "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'},
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert '{"a": "1"}' in data

@ -190,7 +225,7 @@ def test_json(setup_http_mock):


def test_x_www_form_urlencoded(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock):
            "params": "A:b",
            "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "a=1&b=2" in data

@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock):


def test_form_data(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -242,11 +277,11 @@ def test_form_data(setup_http_mock):
            "params": "A:b",
            "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert 'form-data; name="a"' in data

@ -257,7 +292,7 @@ def test_form_data(setup_http_mock):


def test_none_data(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -277,11 +312,11 @@ def test_none_data(setup_http_mock):
            "params": "A:b",
            "body": {"type": "none", "data": "123123123"},
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

    assert "X-Header: 123" in data

@ -289,7 +324,7 @@ def test_none_data(setup_http_mock):


def test_mock_404(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock):
            "params": "",
            "headers": "X-Header:123",
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.outputs is not None
    resp = result.outputs

    assert 404 == resp.get("status_code")
    assert "Not Found" in resp.get("body")
    assert "Not Found" in resp.get("body", "")


def test_multi_colons_parse(setup_http_mock):
    node = HttpRequestNode(
    node = init_http_node(
        config={
            "id": "1",
            "data": {

@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock):
            "headers": "Referer:http://example3.com\nRedirect:http://example4.com",
            "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"},
        },
        },
        **BASIC_NODE_DATA,
        }
    )

    result = node.run(pool)
    result = node._run()
    assert result.process_data is not None
    assert result.outputs is not None
    resp = result.outputs

    assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request")
    assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request")
    assert "http://example3.com" == resp.get("headers").get("referer")
    assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
    assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "")
    assert "http://example3.com" == resp.get("headers", {}).get("referer")
@ -1,5 +1,8 @@
import json
import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock

import pytest

@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_execute_llm(setup_openai_mock):
    node = LLMNode(
def init_llm_node(config: dict) -> LLMNode:
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather today?",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["abc", "output"], "sunny")

    node = LLMNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )

    return node


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_execute_llm(setup_openai_mock):
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock):
        },
    )

    # construct variable pool
    pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather today?",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
    )
    pool.add(["abc", "output"], "sunny")

    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}

    provider_instance = ModelProviderFactory().get_provider_instance("openai")

@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock):
        model_type_instance=model_type_instance,
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
    model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
    assert model_schema is not None
    model_config = ModelConfigWithCredentialsEntity(
        model="gpt-3.5-turbo",
        provider="openai",
        mode="chat",
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
        model_schema=model_schema,
        provider_model_bundle=provider_model_bundle,
    )

@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock):
    node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))

    # execute node
    result = node.run(pool)
    result = node._run()
    assert isinstance(result, Generator)

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs["text"] is not None
    assert result.outputs["usage"]["total_tokens"] > 0
    for item in result:
        if isinstance(item, RunCompletedEvent):
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.process_data is not None
            assert item.run_result.outputs is not None
            assert item.run_result.outputs.get("text") is not None
            assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)

@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
    """
    Test execute LLM node with jinja2
    """
    node = LLMNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {

@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
        },
    )

    # construct variable pool
    pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather today?",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
    )
    pool.add(["abc", "output"], "sunny")

    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}

    provider_instance = ModelProviderFactory().get_provider_instance("openai")

@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
    )

    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")

    model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
    assert model_schema is not None
    model_config = ModelConfigWithCredentialsEntity(
        model="gpt-3.5-turbo",
        provider="openai",
        mode="chat",
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
        model_schema=model_schema,
        provider_model_bundle=provider_model_bundle,
    )

@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
    node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))

    # execute node
    result = node.run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert "sunny" in json.dumps(result.process_data)
    assert "what's the weather today?" in json.dumps(result.process_data)
    for item in result:
        if isinstance(item, RunCompletedEvent):
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.process_data is not None
            assert "sunny" in json.dumps(item.run_result.process_data)
            assert "what's the weather today?" in json.dumps(item.run_result.process_data)
@ -1,5 +1,7 @@
import json
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock

@ -8,19 +10,21 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
from models.provider import ProviderType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock

@ -47,13 +51,15 @@ def get_mocked_fetch_model_config(
        model_type_instance=model_type_instance,
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
    model_schema = model_type_instance.get_model_schema(model)
    assert model_schema is not None
    model_config = ModelConfigWithCredentialsEntity(
        model=model,
        provider=provider,
        mode=mode,
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema(model),
        model_schema=model_schema,
        provider_model_bundle=provider_model_bundle,
    )

@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str):
    return MagicMock(return_value=MemoryMock())


def init_parameter_extractor_node(config: dict):
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather in SF",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["a", "b123", "args1"], 1)
    variable_pool.add(["a", "b123", "args2"], 2)

    return ParameterExtractorNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_function_calling_parameter_extractor(setup_openai_mock):
    """
    Test function calling for parameter extractor.
    """
    node = ParameterExtractorNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
|
|||
"reasoning_mode": "function_call",
|
||||
"memory": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
|
@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
|
|||
environment_variables=[],
|
||||
)
|
||||
|
||||
result = node.run(pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs.get("location") == "kawaii"
|
||||
assert result.outputs.get("__reason") == None
|
||||
|
||||
|
@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock):
|
|||
"""
|
||||
Test chat parameter extractor.
|
||||
"""
|
||||
node = ParameterExtractorNode(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
user_id="1",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
node = init_parameter_extractor_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
|
@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock):
|
|||
)
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
result = node.run(pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs.get("location") == "kawaii"
|
||||
assert result.outputs.get("__reason") == None
|
||||
|
||||
process_data = result.process_data
|
||||
|
||||
assert process_data is not None
|
||||
process_data.get("prompts")
|
||||
|
||||
for prompt in process_data.get("prompts"):
|
||||
for prompt in process_data.get("prompts", []):
|
||||
if prompt.get("role") == "system":
|
||||
assert "what's the weather in SF" in prompt.get("text")
|
||||
|
||||
|
@@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
    """
    Test chat parameter extractor.
    """
    node = ParameterExtractorNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {

@@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
    )
    db.session.close = MagicMock()

    # construct variable pool
    pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather in SF",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
    )

    result = node.run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    prompts = result.process_data.get("prompts")
    assert result.process_data is not None
    prompts = result.process_data.get("prompts", [])

    for prompt in prompts:
        if prompt.get("role") == "user":

@@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock):
    """
    Test completion parameter extractor.
    """
    node = ParameterExtractorNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {

@@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock):
    )
    db.session.close = MagicMock()

    # construct variable pool
    pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather in SF",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
    )

    result = node.run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    assert len(result.process_data.get("prompts")) == 1
    assert "SF" in result.process_data.get("prompts")[0].get("text")
    assert result.process_data is not None
    assert len(result.process_data.get("prompts", [])) == 1
    assert "SF" in result.process_data.get("prompts", [])[0].get("text")


def test_extract_json_response():

@@ -322,13 +325,7 @@ def test_extract_json_response():
    Test extract json response.
    """

    node = ParameterExtractorNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {

@@ -357,6 +354,7 @@ def test_extract_json_response():
    hello world.
    """)

    assert result is not None
    assert result["location"] == "kawaii"


@@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
    """
    Test chat parameter extractor with memory.
    """
    node = ParameterExtractorNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {

@@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
    node._fetch_memory = get_mocked_fetch_memory("customized memory")
    db.session.close = MagicMock()

    # construct variable pool
    pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather in SF",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
    )

    result = node.run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    prompts = result.process_data.get("prompts")
    assert result.process_data is not None
    prompts = result.process_data.get("prompts", [])

    latest_role = None
    for prompt in prompts:
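All of the rewritten parameter extractor tests above now build the node through a shared init_parameter_extractor_node helper and call node._run() with no arguments, instead of constructing ParameterExtractorNode with tenant/app/user identifiers and passing a pool to node.run(). The helper's body falls outside these hunks; judging from the init_tool_node helper shown further down in this diff, it plausibly has the shape below. This is a sketch modeled on that helper (graph wiring, ids, and the system-variable seeds are assumptions, not the committed code), using the same project imports as the tool-node test:

import time
import uuid

def init_parameter_extractor_node(config: dict):
    # Sketch only: mirrors init_tool_node below; the real helper may differ.
    graph_config = {
        "edges": [{"id": "start-source-llm-target", "source": "start", "target": "llm"}],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }
    graph = Graph.init(graph_config=graph_config)
    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )
    variable_pool = VariablePool(
        system_variables={SystemVariableKey.QUERY: "what's the weather in SF", SystemVariableKey.FILES: []},
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    return ParameterExtractorNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )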
@@ -1,46 +1,84 @@
import time
import uuid

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
    code = """{{args2}}"""
    node = TemplateTransformNode(
    config = {
        "id": "1",
        "data": {
            "title": "123",
            "variables": [
                {"variable": "args1", "value_selector": ["1", "123", "args1"]},
                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
            ],
            "template": code,
        },
    }

    graph_config = {
        "edges": [
            {"id": "start-source-next-target", "source": "start", "target": "1"},
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.END_USER,
        config={
            "id": "1",
            "data": {
                "title": "123",
                "variables": [
                    {"variable": "args1", "value_selector": ["1", "123", "args1"]},
                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
                ],
                "template": code,
            },
        },
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(["1", "123", "args1"], 1)
    pool.add(["1", "123", "args2"], 3)
    variable_pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["1", "123", "args1"], 1)
    variable_pool.add(["1", "123", "args2"], 3)

    node = TemplateTransformNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )

    # execute node
    result = node.run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["output"] == "3"
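Beyond the graph scaffolding, the substantive change this hunk illustrates is where inputs live: the old node.run(pool) took the variable pool as a call-time argument, while the new node reads it from the GraphRuntimeState it was constructed with, so tests seed values up front. In sketch form, with the same names as the test above:

# Old style: pool handed in at call time.
#     result = node.run(pool)
# New style: seed the pool owned by the runtime state, then run with no arguments.
variable_pool.add(["1", "123", "args2"], 3)
result = node._run()  # reads node.graph_runtime_state.variable_pool internally
assert result.outputs["output"] == "3"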
@@ -1,21 +1,62 @@
import time
import uuid

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.tool.tool_node import ToolNode
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


def init_tool_node(config: dict):
    graph_config = {
        "edges": [
            {"id": "start-source-next-target", "source": "start", "target": "1"},
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )

    return ToolNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )


def test_tool_variable_invoke():
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(["1", "123", "args1"], "1+1")

    node = ToolNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_tool_node(
        config={
            "id": "1",
            "data": {

@@ -34,28 +75,22 @@ def test_tool_variable_invoke():
                    }
                },
            },
        },
    }
    )

    # execute node
    result = node.run(pool)
    node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")

    # execute node
    result = node._run()
    assert isinstance(result, NodeRunResult)
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert "2" in result.outputs["text"]
    assert result.outputs["files"] == []


def test_tool_mixed_invoke():
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(["1", "args1"], "1+1")

    node = ToolNode(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        user_id="1",
        invoke_from=InvokeFrom.WEB_APP,
        user_from=UserFrom.ACCOUNT,
    node = init_tool_node(
        config={
            "id": "1",
            "data": {

@@ -74,12 +109,15 @@ def test_tool_mixed_invoke():
                    }
                },
            },
        },
    }
    )

    # execute node
    result = node.run(pool)
    node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")

    # execute node
    result = node._run()
    assert isinstance(result, NodeRunResult)
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert "2" in result.outputs["text"]
    assert result.outputs["files"] == []
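One detail worth noting in both tool tests: because init_tool_node builds the VariablePool internally, the test seeds the tool's input expression through the node's own handle before executing, and ToolNode._run returns a NodeRunResult directly rather than a generator. Reduced to its essentials (the config here is trimmed to a placeholder and is an assumption, not a valid tool config):

node = init_tool_node(config={"id": "1", "data": {"type": "tool", "title": "t"}})  # trimmed config, assumption
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
result = node._run()
assert isinstance(result, NodeRunResult)  # a plain result, unlike streaming nodes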
@@ -1,7 +1,24 @@
import os

import pytest
from flask import Flask

# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))

CACHED_APP = Flask(__name__)
CACHED_APP.config.update({"TESTING": True})


@pytest.fixture()
def app() -> Flask:
    return CACHED_APP


@pytest.fixture(autouse=True)
def _provide_app_context(app: Flask):
    with app.app_context():
        yield
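The effect of this conftest is that every test in the package runs inside a live Flask application context with no per-test setup: the autouse fixture pushes app.app_context() around each test. For example, a test can read current_app directly:

from flask import current_app

def test_reads_app_config():
    # No fixture parameters needed; _provide_app_context is autouse.
    assert current_app.config["TESTING"] is True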
api/tests/unit_tests/core/workflow/graph_engine/test_graph.py (new file, 791 lines)
@@ -0,0 +1,791 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition


def test_init():
    graph_config = {
        "edges": [
            {"id": "llm-source-answer-target", "source": "llm", "target": "answer"},
            {"id": "start-source-qc-target", "source": "start", "target": "qc"},
            {"id": "qc-1-llm-target", "source": "qc", "sourceHandle": "1", "target": "llm"},
            {"id": "qc-2-http-target", "source": "qc", "sourceHandle": "2", "target": "http"},
            {"id": "http-source-answer2-target", "source": "http", "target": "answer2"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
            {"data": {"type": "question-classifier"}, "id": "qc"},
            {"data": {"type": "http-request"}, "id": "http"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer2"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    start_node_id = "start"

    assert graph.root_node_id == start_node_id
    assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
    assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}


def test__init_iteration_graph():
    graph_config = {
        "edges": [
            {"id": "llm-answer", "source": "llm", "sourceHandle": "source", "target": "answer"},
            {"id": "iteration-source-llm-target", "source": "iteration", "sourceHandle": "source", "target": "llm"},
            {
                "id": "template-transform-in-iteration-source-llm-in-iteration-target",
                "source": "template-transform-in-iteration",
                "sourceHandle": "source",
                "target": "llm-in-iteration",
            },
            {
                "id": "llm-in-iteration-source-answer-in-iteration-target",
                "source": "llm-in-iteration",
                "sourceHandle": "source",
                "target": "answer-in-iteration",
            },
            {"id": "start-source-code-target", "source": "start", "sourceHandle": "source", "target": "code"},
            {"id": "code-source-iteration-target", "source": "code", "sourceHandle": "source", "target": "iteration"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
            {"data": {"type": "iteration"}, "id": "iteration"},
            {"data": {"type": "template-transform"}, "id": "template-transform-in-iteration", "parentId": "iteration"},
            {"data": {"type": "llm"}, "id": "llm-in-iteration", "parentId": "iteration"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer-in-iteration", "parentId": "iteration"},
            {"data": {"type": "code"}, "id": "code"},
        ],
    }

    graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
    graph.add_extra_edge(
        source_node_id="answer-in-iteration",
        target_node_id="template-transform-in-iteration",
        run_condition=RunCondition(
            type="condition",
            conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")],
        ),
    )

    # iteration:
    # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]

    assert graph.root_node_id == "template-transform-in-iteration"
    assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
    assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
    assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"


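The add_extra_edge call above is what closes the iteration body into a loop: the run condition re-enters the loop head while iteration.index is at most 5, and once it fails no edge fires, ending the sub-graph run. One quick way to confirm the loop-back edge landed (reading run_condition off the edge object is an assumption about the entity's field name, not something this test asserts):

loop_edges = graph.edge_mapping["answer-in-iteration"]
assert loop_edges[0].target_node_id == "template-transform-in-iteration"
assert loop_edges[0].run_condition is not None  # assumed attribute name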
def test_parallels_graph():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm1-source-answer-target", "source": "llm1", "target": "answer"},
            {"id": "llm2-source-answer-target", "source": "llm2", "target": "answer"},
            {"id": "llm3-source-answer-target", "source": "llm3", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(3):
        start_edges = graph.edge_mapping.get("start")
        assert start_edges is not None
        assert start_edges[i].target_node_id == f"llm{i+1}"

        llm_edges = graph.edge_mapping.get(f"llm{i+1}")
        assert llm_edges is not None
        assert llm_edges[0].target_node_id == "answer"

    assert len(graph.parallel_mapping) == 1
    assert len(graph.node_parallel_mapping) == 3

    for node_id in ["llm1", "llm2", "llm3"]:
        assert node_id in graph.node_parallel_mapping


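test_parallels_graph pins down the two structures the parallel-execution feature derives at graph-build time: parallel_mapping holds one entry per detected parallel region, and node_parallel_mapping maps each node id to the region containing it. A three-way fan-out that reconverges is a single region containing exactly the three branch nodes, while start and answer stay outside it. A compact probe of that invariant, using only the attributes the test reads:

parallel_id = graph.node_parallel_mapping["llm1"]
assert all(graph.node_parallel_mapping[n] == parallel_id for n in ("llm1", "llm2", "llm3"))
assert "start" not in graph.node_parallel_mapping
assert "answer" not in graph.node_parallel_mapping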
def test_parallels_graph2():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm1-source-answer-target", "source": "llm1", "target": "answer"},
            {"id": "llm2-source-answer-target", "source": "llm2", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(3):
        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"

        if i < 2:
            assert graph.edge_mapping.get(f"llm{i + 1}") is not None
            assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"

    assert len(graph.parallel_mapping) == 1
    assert len(graph.node_parallel_mapping) == 3

    for node_id in ["llm1", "llm2", "llm3"]:
        assert node_id in graph.node_parallel_mapping


def test_parallels_graph3():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(3):
        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"

    assert len(graph.parallel_mapping) == 1
    assert len(graph.node_parallel_mapping) == 3

    for node_id in ["llm1", "llm2", "llm3"]:
        assert node_id in graph.node_parallel_mapping


def test_parallels_graph4():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm1-source-answer-target", "source": "llm1", "target": "code1"},
            {"id": "llm2-source-answer-target", "source": "llm2", "target": "code2"},
            {"id": "llm3-source-code3-target", "source": "llm3", "target": "code3"},
            {"id": "code1-source-answer-target", "source": "code1", "target": "answer"},
            {"id": "code2-source-answer-target", "source": "code2", "target": "answer"},
            {"id": "code3-source-answer-target", "source": "code3", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "code"}, "id": "code1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "code"}, "id": "code2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "code"}, "id": "code3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(3):
        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
        assert graph.edge_mapping.get(f"llm{i + 1}") is not None
        assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
        assert graph.edge_mapping.get(f"code{i + 1}") is not None
        assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"

    assert len(graph.parallel_mapping) == 1
    assert len(graph.node_parallel_mapping) == 6

    for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
        assert node_id in graph.node_parallel_mapping


def test_parallels_graph5():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm4"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm5"},
            {"id": "llm1-source-code1-target", "source": "llm1", "target": "code1"},
            {"id": "llm2-source-code1-target", "source": "llm2", "target": "code1"},
            {"id": "llm3-source-code2-target", "source": "llm3", "target": "code2"},
            {"id": "llm4-source-code2-target", "source": "llm4", "target": "code2"},
            {"id": "llm5-source-code3-target", "source": "llm5", "target": "code3"},
            {"id": "code1-source-answer-target", "source": "code1", "target": "answer"},
            {"id": "code2-source-answer-target", "source": "code2", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "code"}, "id": "code1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "code"}, "id": "code2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "code"}, "id": "code3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
            {"data": {"type": "llm"}, "id": "llm4"},
            {"data": {"type": "llm"}, "id": "llm5"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(5):
        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"

    assert graph.edge_mapping.get("llm1") is not None
    assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
    assert graph.edge_mapping.get("llm2") is not None
    assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
    assert graph.edge_mapping.get("llm3") is not None
    assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
    assert graph.edge_mapping.get("llm4") is not None
    assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
    assert graph.edge_mapping.get("llm5") is not None
    assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
    assert graph.edge_mapping.get("code1") is not None
    assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
    assert graph.edge_mapping.get("code2") is not None
    assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"

    assert len(graph.parallel_mapping) == 1
    assert len(graph.node_parallel_mapping) == 8

    for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
        assert node_id in graph.node_parallel_mapping


def test_parallels_graph6():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm1-source-code1-target", "source": "llm1", "target": "code1"},
            {"id": "llm1-source-code2-target", "source": "llm1", "target": "code2"},
            {"id": "llm2-source-code3-target", "source": "llm2", "target": "code3"},
            {"id": "code1-source-answer-target", "source": "code1", "target": "answer"},
            {"id": "code2-source-answer-target", "source": "code2", "target": "answer"},
            {"id": "code3-source-answer-target", "source": "code3", "target": "answer"},
            {"id": "llm3-source-answer-target", "source": "llm3", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "code"}, "id": "code1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "code"}, "id": "code2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "code"}, "id": "code3"},
            {"data": {"type": "answer", "title": "answer", "answer": "1"}, "id": "answer"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    assert graph.root_node_id == "start"
    for i in range(3):
        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"

    assert graph.edge_mapping.get("llm1") is not None
    assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
    assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
    assert graph.edge_mapping.get("llm2") is not None
    assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
    assert graph.edge_mapping.get("code1") is not None
    assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
    assert graph.edge_mapping.get("code2") is not None
    assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
    assert graph.edge_mapping.get("code3") is not None
    assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"

    assert len(graph.parallel_mapping) == 2
    assert len(graph.node_parallel_mapping) == 6

    for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
        assert node_id in graph.node_parallel_mapping

    parent_parallel = None
    child_parallel = None
    for parallel in graph.parallel_mapping.values():
        if parallel.parent_parallel_id is None:
            parent_parallel = parallel
        else:
            child_parallel = parallel

    for node_id in ["llm1", "llm2", "llm3", "code3"]:
        assert graph.node_parallel_mapping[node_id] == parent_parallel.id

    for node_id in ["code1", "code2"]:
        assert graph.node_parallel_mapping[node_id] == child_parallel.id
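test_parallels_graph6 is the nesting case: llm1 itself fans out into code1 and code2, so the graph builder creates a child parallel inside the parent one, linked through parent_parallel_id. A small walk over that hierarchy, using the same attributes the test reads:

for parallel in graph.parallel_mapping.values():
    if parallel.parent_parallel_id is None:
        print("top-level parallel:", parallel.id)
    else:
        print("nested:", parallel.id, "inside", parallel.parent_parallel_id)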
@@ -0,0 +1,505 @@
from unittest.mock import patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
    BaseNodeEvent,
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_parallel_in_workflow(mock_close, mock_remove):
    graph_config = {
        "edges": [
            {"id": "1", "source": "start", "target": "llm1"},
            {"id": "2", "source": "llm1", "target": "llm2"},
            {"id": "3", "source": "llm1", "target": "llm3"},
            {"id": "4", "source": "llm2", "target": "end1"},
            {"id": "5", "source": "llm3", "target": "end2"},
        ],
        "nodes": [
            {
                "data": {
                    "type": "start",
                    "title": "start",
                    "variables": [
                        {
                            "label": "query",
                            "max_length": 48,
                            "options": [],
                            "required": True,
                            "type": "text-input",
                            "variable": "query",
                        }
                    ],
                },
                "id": "start",
            },
            {
                "data": {
                    "type": "llm",
                    "title": "llm1",
                    "context": {"enabled": False, "variable_selector": []},
                    "model": {
                        "completion_params": {"temperature": 0.7},
                        "mode": "chat",
                        "name": "gpt-4o",
                        "provider": "openai",
                    },
                    "prompt_template": [
                        {"role": "system", "text": "say hi"},
                        {"role": "user", "text": "{{#start.query#}}"},
                    ],
                    "vision": {"configs": {"detail": "high"}, "enabled": False},
                },
                "id": "llm1",
            },
            {
                "data": {
                    "type": "llm",
                    "title": "llm2",
                    "context": {"enabled": False, "variable_selector": []},
                    "model": {
                        "completion_params": {"temperature": 0.7},
                        "mode": "chat",
                        "name": "gpt-4o",
                        "provider": "openai",
                    },
                    "prompt_template": [
                        {"role": "system", "text": "say bye"},
                        {"role": "user", "text": "{{#start.query#}}"},
                    ],
                    "vision": {"configs": {"detail": "high"}, "enabled": False},
                },
                "id": "llm2",
            },
            {
                "data": {
                    "type": "llm",
                    "title": "llm3",
                    "context": {"enabled": False, "variable_selector": []},
                    "model": {
                        "completion_params": {"temperature": 0.7},
                        "mode": "chat",
                        "name": "gpt-4o",
                        "provider": "openai",
                    },
                    "prompt_template": [
                        {"role": "system", "text": "say good morning"},
                        {"role": "user", "text": "{{#start.query#}}"},
                    ],
                    "vision": {"configs": {"detail": "high"}, "enabled": False},
                },
                "id": "llm3",
            },
            {
                "data": {
                    "type": "end",
                    "title": "end1",
                    "outputs": [
                        {"value_selector": ["llm2", "text"], "variable": "result2"},
                        {"value_selector": ["start", "query"], "variable": "query"},
                    ],
                },
                "id": "end1",
            },
            {
                "data": {
                    "type": "end",
                    "title": "end2",
                    "outputs": [
                        {"value_selector": ["llm1", "text"], "variable": "result1"},
                        {"value_selector": ["llm3", "text"], "variable": "result3"},
                    ],
                },
                "id": "end2",
            },
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
    )

    graph_engine = GraphEngine(
        tenant_id="111",
        app_id="222",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="333",
        graph_config=graph_config,
        user_id="444",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        call_depth=0,
        graph=graph,
        variable_pool=variable_pool,
        max_execution_steps=500,
        max_execution_time=1200,
    )

    def llm_generator(self):
        contents = ["hi", "bye", "good morning"]

        yield RunStreamChunkEvent(
            chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"]
        )

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={},
                process_data={},
                outputs={},
                metadata={
                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
                    NodeRunMetadataKey.TOTAL_PRICE: 1,
                    NodeRunMetadataKey.CURRENCY: "USD",
                },
            )
        )

    with patch.object(LLMNode, "_run", new=llm_generator):
        items = []
        generator = graph_engine.run()
        for item in generator:
            items.append(item)
            if isinstance(item, NodeRunSucceededEvent):
                assert item.route_node_state.status == RouteNodeState.Status.SUCCESS

            assert not isinstance(item, NodeRunFailedEvent)
            assert not isinstance(item, GraphRunFailedEvent)

            if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]:
                assert item.parallel_id is not None

    assert len(items) == 18
    assert isinstance(items[0], GraphRunStartedEvent)
    assert isinstance(items[1], NodeRunStartedEvent)
    assert items[1].route_node_state.node_id == "start"
    assert isinstance(items[2], NodeRunSucceededEvent)
    assert items[2].route_node_state.node_id == "start"


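The patching trick above is what keeps this an engine test rather than a model test: LLMNode._run is swapped for a plain generator, so each "LLM" instantly yields one stream chunk and one completion event without touching a model provider. Any node-level generator of that shape drives the engine; a stripped-down variant (the token/price metadata is omitted here, which assumes those NodeRunResult fields are optional):

def fake_llm_run(self):
    # One chunk keyed off the node id, then a successful completion.
    yield RunStreamChunkEvent(chunk_content=f"out-{self.node_id}", from_variable_selector=[self.node_id, "text"])
    yield RunCompletedEvent(
        run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, process_data={}, outputs={})
    )

with patch.object(LLMNode, "_run", new=fake_llm_run):
    events = list(graph_engine.run())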
@patch("extensions.ext_database.db.session.remove")
|
||||
@patch("extensions.ext_database.db.session.close")
|
||||
def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "answer1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "answer1",
|
||||
"target": "answer2",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "answer1",
|
||||
"target": "answer3",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "answer2",
|
||||
"target": "answer4",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "answer3",
|
||||
"target": "answer5",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "start"}, "id": "start"},
|
||||
{"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer2", "answer": "2"},
|
||||
"id": "answer2",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer3", "answer": "3"},
|
||||
"id": "answer3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer4", "answer": "4"},
|
||||
"id": "answer4",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer5", "answer": "5"},
|
||||
"id": "answer5",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
if isinstance(item, NodeRunSucceededEvent):
|
||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||
|
||||
assert not isinstance(item, NodeRunFailedEvent)
|
||||
assert not isinstance(item, GraphRunFailedEvent)
|
||||
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||
"answer2",
|
||||
"answer3",
|
||||
"answer4",
|
||||
"answer5",
|
||||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 23
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == "start"
|
||||
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||
assert items[2].route_node_state.node_id == "start"
|
||||
|
||||
|
||||
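When a count assertion like len(items) == 23 breaks, the fastest diagnostic is a histogram of event types rather than a raw dump. A small hypothetical helper (not part of the diff):

from collections import Counter

def event_histogram(events):
    """Count events by class name, e.g. {'NodeRunStartedEvent': 7, ...}."""
    return Counter(type(e).__name__ for e in events)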
@patch("extensions.ext_database.db.session.remove")
|
||||
@patch("extensions.ext_database.db.session.close")
|
||||
def test_run_branch(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "if-else-1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "if-else-1",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-1",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "if-else-1",
|
||||
"sourceHandle": "false",
|
||||
"target": "if-else-2",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "if-else-2",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "if-else-2",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-3",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"title": "Start",
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"label": "uid",
|
||||
"max_length": 48,
|
||||
"options": [],
|
||||
"required": True,
|
||||
"type": "text-input",
|
||||
"variable": "uid",
|
||||
}
|
||||
],
|
||||
},
|
||||
"id": "start",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []},
|
||||
"id": "answer-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "contains",
|
||||
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
|
||||
"value": "hi",
|
||||
"varType": "string",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"id": "true",
|
||||
"logical_operator": "and",
|
||||
}
|
||||
],
|
||||
"desc": "",
|
||||
"title": "IF/ELSE",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "contains",
|
||||
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
|
||||
"value": "takatost",
|
||||
"varType": "string",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"id": "true",
|
||||
"logical_operator": "and",
|
||||
}
|
||||
],
|
||||
"title": "IF/ELSE 2",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "2",
|
||||
"title": "Answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "3",
|
||||
"title": "Answer 3",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-3",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "hi",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
user_inputs={"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
|
||||
assert len(items) == 10
|
||||
assert items[3].route_node_state.node_id == "if-else-1"
|
||||
assert items[4].route_node_state.node_id == "if-else-1"
|
||||
assert isinstance(items[5], NodeRunStreamChunkEvent)
|
||||
assert items[5].chunk_content == "1 "
|
||||
assert isinstance(items[6], NodeRunStreamChunkEvent)
|
||||
assert items[6].chunk_content == "takato"
|
||||
assert items[7].route_node_state.node_id == "answer-1"
|
||||
assert items[8].route_node_state.node_id == "answer-1"
|
||||
assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato"
|
||||
assert isinstance(items[9], GraphRunSucceededEvent)
|
||||
|
||||
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
|
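test_run_branch shows how branching composes with streaming: each if-else case id doubles as an edge sourceHandle, and the engine follows only the edge whose handle matches the case that evaluated true. With sys.query set to "hi", if-else-1's "contains hi" case fires, so only answer-1 runs and streams "1 takato". The edge shape that encodes the routing:

# One edge per case id; the engine picks the edge whose sourceHandle matches the winning case.
branch_edges = [
    {"id": "2", "source": "if-else-1", "sourceHandle": "true", "target": "answer-1"},
    {"id": "3", "source": "if-else-1", "sourceHandle": "false", "target": "if-else-2"},
]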
@@ -0,0 +1,82 @@
import time
import uuid
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


def test_execute_answer():
    graph_config = {
        "edges": [
            {"id": "start-source-llm-target", "source": "start", "target": "llm"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
    )
    pool.add(["start", "weather"], "sunny")
    pool.add(["llm", "text"], "You are a helpful AI.")

    node = AnswerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
        config={
            "id": "answer",
            "data": {
                "title": "123",
                "type": "answer",
                "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
            },
        },
    )

    # Mock db.session.close()
    db.session.close = MagicMock()

    # execute node
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
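The final assertion here is the interesting one: {{#start.weather#}} and {{#llm.text#}} resolve against the variable pool, while {{img}} is not a variable reference in this template syntax and passes through verbatim instead of raising or rendering empty:

template = "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin."
expected = "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."  # {{img}} untouched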
@@ -0,0 +1,109 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter


def test_init():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm3-source-llm4-target", "source": "llm3", "target": "llm4"},
            {"id": "llm3-source-llm5-target", "source": "llm3", "target": "llm5"},
            {"id": "llm4-source-answer2-target", "source": "llm4", "target": "answer2"},
            {"id": "llm5-source-answer-target", "source": "llm5", "target": "answer"},
            {"id": "answer2-source-answer-target", "source": "answer2", "target": "answer"},
            {"id": "llm2-source-answer-target", "source": "llm2", "target": "answer"},
            {"id": "llm1-source-answer-target", "source": "llm1", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "llm"}, "id": "llm4"},
            {"data": {"type": "llm"}, "id": "llm5"},
            {"data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, "id": "answer"},
            {"data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, "id": "answer2"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
        node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping
    )

    assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"]
    assert answer_stream_generate_route.answer_dependencies["answer2"] == []
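answer_dependencies reads as "this answer node must wait for these answer nodes before streaming": answer sits downstream of answer2, so it depends on it, while answer2 depends on nothing. For this two-answer graph the two assertions collapse into a single dict equality (which assumes no other keys, true here):

deps = answer_stream_generate_route.answer_dependencies
assert deps == {"answer": ["answer2"], "answer2": []}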
@@ -0,0 +1,216 @@
import uuid
from collections.abc import Generator
from datetime import datetime, timezone

from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.start.entities import StartNodeData


def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
    if next_node_id == "start":
        yield from _publish_events(graph, next_node_id)

    for edge in graph.edge_mapping.get(next_node_id, []):
        yield from _publish_events(graph, edge.target_node_id)

    for edge in graph.edge_mapping.get(next_node_id, []):
        yield from _recursive_process(graph, edge.target_node_id)


def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
    route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))

    parallel_id = graph.node_parallel_mapping.get(next_node_id)
    parallel_start_node_id = None
    if parallel_id:
        parallel = graph.parallel_mapping.get(parallel_id)
        parallel_start_node_id = parallel.start_from_node_id if parallel else None

    node_execution_id = str(uuid.uuid4())
    node_config = graph.node_id_config_mapping[next_node_id]
    node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
    mock_node_data = StartNodeData(**{"title": "demo", "variables": []})

    yield NodeRunStartedEvent(
        id=node_execution_id,
        node_id=next_node_id,
        node_type=node_type,
        node_data=mock_node_data,
        route_node_state=route_node_state,
        parallel_id=graph.node_parallel_mapping.get(next_node_id),
        parallel_start_node_id=parallel_start_node_id,
    )

    if "llm" in next_node_id:
        length = int(next_node_id[-1])
        for i in range(0, length):
            yield NodeRunStreamChunkEvent(
                id=node_execution_id,
                node_id=next_node_id,
                node_type=node_type,
                node_data=mock_node_data,
                chunk_content=str(i),
                route_node_state=route_node_state,
                from_variable_selector=[next_node_id, "text"],
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
            )

    route_node_state.status = RouteNodeState.Status.SUCCESS
    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
    yield NodeRunSucceededEvent(
        id=node_execution_id,
        node_id=next_node_id,
        node_type=node_type,
        node_data=mock_node_data,
        route_node_state=route_node_state,
        parallel_id=parallel_id,
        parallel_start_node_id=parallel_start_node_id,
    )


def test_process():
    graph_config = {
        "edges": [
            {"id": "start-source-llm1-target", "source": "start", "target": "llm1"},
            {"id": "start-source-llm2-target", "source": "start", "target": "llm2"},
            {"id": "start-source-llm3-target", "source": "start", "target": "llm3"},
            {"id": "llm3-source-llm4-target", "source": "llm3", "target": "llm4"},
            {"id": "llm3-source-llm5-target", "source": "llm3", "target": "llm5"},
            {"id": "llm4-source-answer2-target", "source": "llm4", "target": "answer2"},
            {"id": "llm5-source-answer-target", "source": "llm5", "target": "answer"},
            {"id": "answer2-source-answer-target", "source": "answer2", "target": "answer"},
            {"id": "llm2-source-answer-target", "source": "llm2", "target": "answer"},
            {"id": "llm1-source-answer-target", "source": "llm1", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm1"},
            {"data": {"type": "llm"}, "id": "llm2"},
            {"data": {"type": "llm"}, "id": "llm3"},
            {"data": {"type": "llm"}, "id": "llm4"},
            {"data": {"type": "llm"}, "id": "llm5"},
            {"data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"}, "id": "answer"},
            {"data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"}, "id": "answer2"},
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather in SF",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
    )

    answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)

    def graph_generator() -> Generator[GraphEngineEvent, None, None]:
        for event in _recursive_process(graph, "start"):
            if isinstance(event, NodeRunSucceededEvent):
                if "llm" in event.route_node_state.node_id:
                    variable_pool.add(
                        [event.route_node_state.node_id, "text"],
                        "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))),
                    )
            yield event

    result_generator = answer_stream_processor.process(graph_generator())
    stream_contents = ""
    for event in result_generator:
        if isinstance(event, NodeRunStreamChunkEvent):
            stream_contents += event.chunk_content

    assert stream_contents == "c012da01b"
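The expected stream "c012da01b" encodes the whole contract: answer2 ("c" + llm3's text + "d") streams before answer ("a" + llm2's text + "b") because answer depends on answer2, and each mocked llmN contributes the digits 0..N-1. The arithmetic checks out:

def llm_text(n: int) -> str:
    return "".join(str(i) for i in range(n))  # what the mock seeds for llmN

assert "c" + llm_text(3) + "d" + "a" + llm_text(2) + "b" == "c012da01b"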
Some files were not shown because too many files have changed in this diff.