From dabfd74622a613c7198f06791ad70424ac94f54f Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Sep 2024 15:23:16 +0800 Subject: [PATCH] feat: Parallel Execution of Nodes in Workflows (#8192) Co-authored-by: StyleZhang Co-authored-by: Yi Co-authored-by: -LAN- --- api/configs/packaging/__init__.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 167 +-- api/core/app/apps/advanced_chat/app_runner.py | 300 ++--- .../advanced_chat/generate_task_pipeline.py | 660 ++++------- .../workflow_event_trigger_callback.py | 203 ---- .../base_app_generate_response_converter.py | 2 +- api/core/app/apps/base_app_runner.py | 6 +- api/core/app/apps/workflow/app_generator.py | 65 +- api/core/app/apps/workflow/app_runner.py | 149 ++- .../apps/workflow/generate_task_pipeline.py | 433 +++---- .../workflow_event_trigger_callback.py | 200 ---- api/core/app/apps/workflow_app_runner.py | 379 +++++++ .../app/apps/workflow_logging_callback.py | 280 +++-- api/core/app/entities/queue_entities.py | 178 ++- api/core/app/entities/task_entities.py | 156 +-- .../based_generate_task_pipeline.py | 14 +- .../app/task_pipeline/message_cycle_manage.py | 36 +- .../task_pipeline/workflow_cycle_manage.py | 663 ++++++----- .../workflow_cycle_state_manager.py | 16 - .../workflow_iteration_cycle_manage.py | 290 ----- .../model_runtime/entities/llm_entities.py | 33 + api/core/moderation/output_moderation.py | 8 +- api/core/tools/tool/workflow_tool.py | 4 +- api/core/tools/tool_engine.py | 2 + api/core/tools/tool_manager.py | 3 +- api/core/tools/utils/message_transformer.py | 1 + .../callbacks/base_workflow_callback.py | 113 +- .../entities/base_node_data_entities.py | 2 +- api/core/workflow/entities/node_entities.py | 34 +- api/core/workflow/entities/variable_pool.py | 82 +- .../workflow/entities/workflow_entities.py | 5 +- api/core/workflow/errors.py | 10 +- api/core/workflow/graph_engine/__init__.py | 0 .../condition_handlers/__init__.py | 0 .../condition_handlers/base_handler.py | 31 + .../branch_identify_handler.py | 28 + .../condition_handlers/condition_handler.py | 32 + .../condition_handlers/condition_manager.py | 35 + .../graph_engine/entities/__init__.py | 0 .../workflow/graph_engine/entities/event.py | 163 +++ .../workflow/graph_engine/entities/graph.py | 692 ++++++++++++ .../entities/graph_init_params.py | 21 + .../entities/graph_runtime_state.py | 27 + .../graph_engine/entities/next_graph_node.py | 13 + .../graph_engine/entities/run_condition.py | 21 + .../entities/runtime_route_state.py | 111 ++ .../workflow/graph_engine/graph_engine.py | 716 ++++++++++++ api/core/workflow/nodes/answer/answer_node.py | 91 +- .../answer/answer_stream_generate_router.py | 169 +++ .../nodes/answer/answer_stream_processor.py | 221 ++++ .../nodes/answer/base_stream_processor.py | 71 ++ api/core/workflow/nodes/answer/entities.py | 42 +- api/core/workflow/nodes/base_node.py | 194 +--- api/core/workflow/nodes/code/code_node.py | 24 +- api/core/workflow/nodes/end/end_node.py | 62 +- .../nodes/end/end_stream_generate_router.py | 148 +++ .../nodes/end/end_stream_processor.py | 191 ++++ api/core/workflow/nodes/end/entities.py | 16 + api/core/workflow/nodes/event.py | 20 + .../nodes/http_request/http_request_node.py | 28 +- api/core/workflow/nodes/if_else/entities.py | 15 +- .../workflow/nodes/if_else/if_else_node.py | 405 +------ api/core/workflow/nodes/iteration/entities.py | 9 +- .../nodes/iteration/iteration_node.py | 435 +++++-- .../nodes/iteration/iteration_start_node.py | 39 + .../knowledge_retrieval_node.py | 32 +- api/core/workflow/nodes/llm/llm_node.py | 166 ++- api/core/workflow/nodes/loop/loop_node.py | 30 +- api/core/workflow/nodes/node_mapping.py | 37 + .../parameter_extractor_node.py | 35 +- .../question_classifier_node.py | 50 +- api/core/workflow/nodes/start/start_node.py | 22 +- .../template_transform_node.py | 26 +- api/core/workflow/nodes/tool/tool_node.py | 23 +- .../variable_aggregator_node.py | 25 +- .../workflow/nodes/variable_assigner/node.py | 13 +- api/core/workflow/utils/condition/__init__.py | 0 api/core/workflow/utils/condition/entities.py | 17 + .../workflow/utils/condition/processor.py | 383 +++++++ api/core/workflow/workflow_engine_manager.py | 1005 ----------------- api/core/workflow/workflow_entry.py | 314 +++++ ...21501b_add_node_execution_id_into_node_.py | 35 + api/models/workflow.py | 3 + api/services/app_dsl_service.py | 3 +- api/services/app_generate_service.py | 7 +- api/services/workflow_service.py | 158 ++- .../workflow/nodes/test_code.py | 331 +++--- .../workflow/nodes/test_http.py | 156 ++- .../workflow/nodes/test_llm.py | 131 ++- .../nodes/test_parameter_extractor.py | 192 ++-- .../workflow/nodes/test_template_transform.py | 84 +- .../workflow/nodes/test_tool.py | 94 +- api/tests/unit_tests/conftest.py | 17 + .../core/workflow/graph_engine/__init__.py | 0 .../core/workflow/graph_engine/test_graph.py | 791 +++++++++++++ .../graph_engine/test_graph_engine.py | 505 +++++++++ .../core/workflow/nodes/answer/__init__.py | 0 .../core/workflow/nodes/answer/test_answer.py | 82 ++ .../test_answer_stream_generate_router.py | 109 ++ .../answer/test_answer_stream_processor.py | 216 ++++ .../core/workflow/nodes/iteration/__init__.py | 0 .../nodes/iteration/test_iteration.py | 420 +++++++ .../core/workflow/nodes/test_answer.py | 65 +- .../core/workflow/nodes/test_if_else.py | 126 ++- .../workflow/nodes/test_variable_assigner.py | 205 +++- docker-legacy/docker-compose.yaml | 6 +- docker/docker-compose.yaml | 6 +- .../chat/chat/answer/workflow-process.tsx | 18 +- web/app/components/base/chat/chat/hooks.ts | 21 +- .../share/text-generation/result/index.tsx | 22 +- .../components/workflow/candidate-node.tsx | 4 + web/app/components/workflow/constants.ts | 21 +- .../workflow/hooks/use-nodes-interactions.ts | 235 ++-- .../workflow/hooks/use-workflow-run.ts | 111 +- .../workflow/hooks/use-workflow-template.ts | 6 +- .../components/workflow/hooks/use-workflow.ts | 58 +- web/app/components/workflow/index.tsx | 5 + web/app/components/workflow/limit-tips.tsx | 39 + .../nodes/_base/components/next-step/add.tsx | 24 +- .../_base/components/next-step/container.tsx | 55 + .../_base/components/next-step/index.tsx | 88 +- .../nodes/_base/components/next-step/item.tsx | 88 +- .../nodes/_base/components/next-step/line.tsx | 90 +- .../_base/components/next-step/operator.tsx | 129 +++ .../nodes/_base/components/node-handle.tsx | 42 +- .../nodes/_base/components/node-resizer.tsx | 4 +- .../components/workflow/nodes/_base/node.tsx | 10 + .../workflow/nodes/if-else/types.ts | 1 - .../workflow/nodes/if-else/utils.ts | 1 - .../nodes/iteration-start/constants.ts | 1 + .../workflow/nodes/iteration-start/default.ts | 21 + .../workflow/nodes/iteration-start/index.tsx | 42 + .../workflow/nodes/iteration-start/types.ts | 3 + .../workflow/nodes/iteration/add-block.tsx | 95 +- .../workflow/nodes/iteration/default.ts | 1 + .../workflow/nodes/iteration/insert-block.tsx | 61 - .../workflow/nodes/iteration/node.tsx | 20 +- .../nodes/iteration/use-interactions.ts | 6 +- .../workflow/operator/add-block.tsx | 2 +- web/app/components/workflow/operator/hooks.ts | 2 +- .../workflow/panel/debug-and-preview/hooks.ts | 107 +- web/app/components/workflow/run/index.tsx | 21 +- .../workflow/run/iteration-result-panel.tsx | 78 +- web/app/components/workflow/run/node.tsx | 53 +- .../components/workflow/run/tracing-panel.tsx | 264 ++++- web/app/components/workflow/store.ts | 4 + web/app/components/workflow/types.ts | 3 +- web/app/components/workflow/utils.ts | 311 ++++- web/i18n/en-US/workflow.ts | 17 +- web/i18n/zh-Hans/workflow.ts | 17 +- web/package.json | 2 +- web/service/base.ts | 18 +- web/themes/dark.css | 1 + web/themes/light.css | 1 + web/themes/tailwind-theme-var-define.ts | 1 + web/types/workflow.ts | 53 + 156 files changed, 11158 insertions(+), 5605 deletions(-) delete mode 100644 api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py delete mode 100644 api/core/app/apps/workflow/workflow_event_trigger_callback.py create mode 100644 api/core/app/apps/workflow_app_runner.py delete mode 100644 api/core/app/task_pipeline/workflow_iteration_cycle_manage.py create mode 100644 api/core/workflow/graph_engine/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/base_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_manager.py create mode 100644 api/core/workflow/graph_engine/entities/__init__.py create mode 100644 api/core/workflow/graph_engine/entities/event.py create mode 100644 api/core/workflow/graph_engine/entities/graph.py create mode 100644 api/core/workflow/graph_engine/entities/graph_init_params.py create mode 100644 api/core/workflow/graph_engine/entities/graph_runtime_state.py create mode 100644 api/core/workflow/graph_engine/entities/next_graph_node.py create mode 100644 api/core/workflow/graph_engine/entities/run_condition.py create mode 100644 api/core/workflow/graph_engine/entities/runtime_route_state.py create mode 100644 api/core/workflow/graph_engine/graph_engine.py create mode 100644 api/core/workflow/nodes/answer/answer_stream_generate_router.py create mode 100644 api/core/workflow/nodes/answer/answer_stream_processor.py create mode 100644 api/core/workflow/nodes/answer/base_stream_processor.py create mode 100644 api/core/workflow/nodes/end/end_stream_generate_router.py create mode 100644 api/core/workflow/nodes/end/end_stream_processor.py create mode 100644 api/core/workflow/nodes/event.py create mode 100644 api/core/workflow/nodes/iteration/iteration_start_node.py create mode 100644 api/core/workflow/nodes/node_mapping.py create mode 100644 api/core/workflow/utils/condition/__init__.py create mode 100644 api/core/workflow/utils/condition/entities.py create mode 100644 api/core/workflow/utils/condition/processor.py create mode 100644 api/core/workflow/workflow_entry.py create mode 100644 api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py create mode 100644 web/app/components/workflow/limit-tips.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/next-step/container.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/next-step/operator.tsx create mode 100644 web/app/components/workflow/nodes/iteration-start/constants.ts create mode 100644 web/app/components/workflow/nodes/iteration-start/default.ts create mode 100644 web/app/components/workflow/nodes/iteration-start/index.tsx create mode 100644 web/app/components/workflow/nodes/iteration-start/types.ts delete mode 100644 web/app/components/workflow/nodes/iteration/insert-block.tsx diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 2d540ca584..e03dfeb27c 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.7.3", + default="0.8.0", ) COMMIT_SHA: str = Field( diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index e7c9ebe097..638cc07461 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,12 +4,10 @@ import os import threading import uuid from collections.abc import Generator -from typing import Literal, Union, overload +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) -> dict: ... def generate( - self, app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ): + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True): + stream: bool = True) \ + -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): if args.get('inputs') is None: raise ValueError('inputs is required') - extras = { - "auto_generate_conversation_name": False - } - - # get conversation - conversation = None - conversation_id = args.get('conversation_id') - if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) - # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( app_model=app_model, @@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - conversation_id=conversation.id if conversation else None, + conversation_id=None, inputs={}, query='', files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={ + "auto_generate_conversation_name": False + }, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args['inputs'] @@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=None, stream=stream ) def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation | None = None, - stream: bool = True): + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True) \ + -> dict[str, Any] | Generator[str, Any, None]: + """ + Generate App response. + + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ is_first_conversation = False if not conversation: is_first_conversation = True @@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - # db.session.refresh(conversation) + db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id ) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - # Create conversation variables if they don't exist. - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] - - session.commit() - - # Increment dialogue count. - conversation.dialogue_count += 1 - - conversation_id = conversation.id - conversation_dialogue_count = conversation.dialogue_count - db.session.commit() - db.session.refresh(conversation) - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - contexts.workflow_variable_pool.set(variable_pool) - # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, + 'conversation_id': conversation.id, 'message_id': message.id, 'context': contextvars.copy_context(), }) @@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, + conversation_id: str, message_id: str, context: contextvars.Context) -> None: """ @@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): var.set(val) with flask_app.app_context(): try: - runner = AdvancedChatAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - # get message - message = self._get_message(message_id) + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) - # chatbot app - runner = AdvancedChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 5dc03979cf..4da3d093d2 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,49 +1,67 @@ import logging import os -import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueStopEvent, + QueueTextChunkEvent, +) from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from models import App, Message, Workflow +from models.model import App, Conversation, EndUser, Message +from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) -class AdvancedChatAppRunner(AppRunner): +class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ AdvancedChat Application Runner """ - def run( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message, + def __init__( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message ) -> None: """ - Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :param conversation: conversation :param message: message + """ + super().__init__(queue_manager) + + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.message = message + + def run(self) -> None: + """ + Run application :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) app_record = db.session.query(App).filter(App.id == app_config.app_id).first() @@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner): if not workflow: raise ValueError('Workflow not initialized') - inputs = application_generate_entity.inputs - query = application_generate_entity.query + user_id = None + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id - # moderation - if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id, - ): - return + workflow_callbacks: list[WorkflowCallback] = [] + if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + workflow_callbacks.append(WorkflowLoggingCallback()) - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - ): - return + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files + + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity + ): + return + + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + self.conversation.dialogue_count += 1 + + conversation_dialogue_count = self.conversation.dialogue_count + db.session.commit() + + # Create a variable pool. + system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + } + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): - workflow_callbacks.append(WorkflowLoggingCallback()) - # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + ) + + generator = workflow_entry.run( callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') - - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] - - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) def handle_input_moderation( - self, - queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str, + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str ) -> bool: """ Handle input moderation - :param queue_manager: application queue manager :param app_record: app record :param app_generate_entity: application generate entity :param inputs: inputs @@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner): message_id=message_id, ) except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=str(e), - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION ) return True return False - def handle_annotation_reply( - self, - app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity, - ) -> bool: + def handle_annotation_reply(self, app_record: App, + message: Message, + query: str, + app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: """ Handle annotation reply :param app_record: app record :param message: message :param query: query - :param queue_manager: application queue manager :param app_generate_entity: application generate entity """ # annotation reply @@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER + self._publish_event( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) ) - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=annotation_reply.content, - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _stream_output( - self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy - ) -> None: + def _complete_with_stream_output(self, + text: str, + stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output - :param queue_manager: application queue manager :param text: text - :param stream: stream :return: """ - if stream: - index = 0 - for token in text: - queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) + self._publish_event( + QueueTextChunkEvent( + text=text + ) + ) - queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) + self._publish_event( + QueueStopEvent(stopped_by=stopped_by) + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2b3596ded2..fb013cd1b1 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,9 +2,8 @@ import json import logging import time from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union -import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -22,6 +21,9 @@ from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, @@ -31,34 +33,28 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, - ChatflowStreamGenerateRoute, ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, + WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.file.file_obj import FileVar -from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowRunStatus, ) @@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: AdvancedChatTaskState + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] - # Deprecated _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, queue_manager: AppQueueManager, conversation: Conversation, @@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._workflow = workflow self._conversation = conversation self._message = message - # Deprecated self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, @@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc SystemVariableKey.USER_ID: user_id, } - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) + self._task_state = WorkflowTaskState() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None def process(self): @@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) + if self._stream: return self._to_stream_response(generator) else: @@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ - for message in self._queue_manager.listen(): - if (message.event - and getattr(message.event, 'metadata', None) - and message.event.metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - elif (hasattr(message.event, 'execution_metadata') - and message.event.execution_metadata - and message.event.execution_metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - event = message.event + # init fake graph runtime state + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + # init workflow run + workflow_run = self._handle_workflow_run_start() + + self._refetch_message() self._message.workflow_run_id = workflow_run.id db.session.commit() @@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: - self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] - # reset current route position to 0 - self._task_state.current_stream_generate_state.current_route_position = 0 + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() - - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - # stream outputs when node finished - generator = self._generate_stream_outputs_when_node_finished() - if generator: - yield from generator + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager + response = self._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution ) - if workflow_run: + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if event.outputs else None, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + elif isinstance(event, QueueWorkflowFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + yield self._error_to_stream_response(self._handle_error(err_event, self._message)) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() - - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message() + self._save_message(graph_runtime_state=graph_runtime_state) yield self._message_end_to_stream_response() + break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) + + self._refetch_message() + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) + + self._refetch_message() + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(delta_text) if should_direct_answer: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text yield self._message_to_stream_response(delta_text, self._message.id) elif isinstance(event, QueueMessageReplaceEvent): + # published by moderation yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + elif isinstance(event, QueueAdvancedChatMessageEndEvent): + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message(graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() else: continue - if publisher: - publisher.publish(None) + + # publish None when task finished + if tts_publisher: + tts_publisher.publish(None) + if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self) -> None: + def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: """ Save message. :return: """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ if self._task_state.metadata else None - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) - + if graph_runtime_state and graph_runtime_state.llm_usage: + usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras['metadata'] = self._task_state.metadata.copy() + + if 'annotation_reply' in extras['metadata']: + del extras['metadata']['annotation_reply'] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, @@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc **extras ) - def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]: - """ - Get stream generate routes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - answer_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.ANSWER.value - ] - - # parse stream output node value selectors of answer nodes - stream_generate_routes = {} - for node_config in answer_node_configs: - # get generate route for stream output - answer_node_id = node_config['id'] - generate_route = AnswerNode.extract_generate_route_selectors(node_config) - start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute( - answer_node_id=answer_node_id, - generate_route=generate_route - ) - - return stream_generate_routes - - def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get answer start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - # check if it's the first node in the iteration - target_node = next((node for node in nodes if node.get('id') == target_node_id), None) - if not target_node: - return [] - - node_iteration_id = target_node.get('data', {}).get('iteration_id') - # get iteration start node id - for node in nodes: - if node.get('id') == node_iteration_id: - if node.get('data', {}).get('start_node_id') == target_node_id: - return [target_node_id] - - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, - NodeType.LOOP.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position: - ] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text) - if should_direct_answer: - continue - - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - break - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: - """ - Generate stream outputs. - :return: - """ - if not self._task_state.current_stream_generate_state: - return - - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position:] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - value = None - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - if not value_selector: - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - route_chunk_node_id = value_selector[0] - - if route_chunk_node_id == 'sys': - # system variable - value = contexts.workflow_variable_pool.get().get(value_selector) - if value: - value = value.text - elif route_chunk_node_id in self._iteration_nested_relations: - # it's a iteration variable - if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: - continue - iteration_state = self._iteration_state.current_iterations[route_chunk_node_id] - iterator = iteration_state.inputs - if not iterator: - continue - iterator_selector = iterator.get('iterator_selector', []) - if value_selector[1] == 'index': - value = iteration_state.current_index - elif value_selector[1] == 'item': - value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len( - iterator_selector - ) else None - else: - # check chunk node id is before current node id or equal to current node id - if route_chunk_node_id not in self._task_state.ran_node_execution_infos: - break - - latest_node_execution_info = self._task_state.latest_node_execution_info - - # get route chunk node execution info - route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] - if (route_chunk_node_execution_info.node_type == NodeType.LLM - and latest_node_execution_info.node_type == NodeType.LLM): - # only LLM support chunk stream output - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - # get route chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id - ).first() - - outputs = route_chunk_node_execution.outputs_dict - - # get value from outputs - value = None - for key in value_selector[1:]: - if not value: - value = outputs.get(key) if outputs else None - else: - value = value.get(key) - - if value is not None: - text = '' - if isinstance(value, str | int | float): - text = str(value) - elif isinstance(value, FileVar): - # convert file to markdown - text = value.to_markdown() - elif isinstance(value, dict): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - if file_vars: - file_var = file_vars[0] - try: - file_var_obj = FileVar(**file_var) - - # convert file to markdown - text = file_var_obj.to_markdown() - except Exception as e: - logger.error(f'Error creating file var: {e}') - - if not text: - # other types - text = json.dumps(value, ensure_ascii=False) - elif isinstance(value, list): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - for file_var in file_vars: - try: - file_var_obj = FileVar(**file_var) - except Exception as e: - logger.error(f'Error creating file var: {e}') - continue - - # convert file to markdown - text = file_var_obj.to_markdown() + ' ' - - text = text.strip() - - if not text and value: - # other types - text = json.dumps(value, ensure_ascii=False) - - if text: - self._task_state.answer += text - yield self._message_to_stream_response(text, self._message.id) - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return True - - if 'node_id' not in event.metadata: - return True - - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - route_chunk = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position] - - if route_chunk.type != 'var': - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - - # check chunk node id is before current node id or equal to current node id - if value_selector != stream_output_value_selector: - return False - - return True - def _handle_output_moderation_chunk(self, text: str) -> bool: """ Handle output moderation chunk. @@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._output_moderation_handler.append_new_token(text) return False + + def _refetch_message(self) -> None: + """ + Refetch message. + :return: + """ + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if message: + self._message = message diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py deleted file mode 100644 index 8d43155a08..0000000000 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager._publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager._publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self._queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 1165314a7f..a196d36be5 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert(cls, response: Union[ AppBlockingResponse, Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom): + ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 2c5feaaaaf..60216959a8 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time -from collections.abc import Generator -from typing import TYPE_CHECKING, Optional, Union +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -347,7 +347,7 @@ class AppRunner: self, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, ) -> tuple[bool, dict, str]: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 26bb6c0f4f..4347e5277b 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Literal, Union, overload +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator): args: dict, invoke_from: InvokeFrom, stream: Literal[True] = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> Generator[str, None, None]: ... @overload @@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator): args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> dict: ... def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ): """ Generate App response. @@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param invoke_from: invoke from source :param stream: is stream :param call_depth: call depth + :param workflow_thread_pool_id: workflow thread pool id """ inputs = args['inputs'] @@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, + workflow_thread_pool_id=workflow_thread_pool_id ) def _generate( - self, app_model: App, + self, *, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + workflow_thread_pool_id: Optional[str] = None + ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param application_generate_entity: application generate entity :param invoke_from: invoke from source :param stream: is stream + :param workflow_thread_pool_id: workflow thread pool id """ # init queue manager queue_manager = WorkflowAppQueueManager( @@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator): # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), + 'workflow_thread_pool_id': workflow_thread_pool_id }) worker_thread.start() @@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True): + stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator): if args.get('inputs') is None: raise ValueError('inputs is required') - extras = { - "auto_generate_conversation_name": False - } - # convert to app config app_config = WorkflowAppConfigManager.get_app_config( app_model=app_model, @@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={ + "auto_generate_conversation_name": False + }, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args['inputs'] @@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator): def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, - context: contextvars.Context) -> None: + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None) -> None: """ Generate worker in a new thread. :param flask_app: Flask app :param application_generate_entity: application generate entity :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id :return: """ for var, val in context.items(): @@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator): with flask_app.app_context(): try: # workflow app - runner = WorkflowAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id + ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: @@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index e388d0184b..9d48db7546 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,46 +4,61 @@ from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig -from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import WorkflowType logger = logging.getLogger(__name__) -class WorkflowAppRunner: +class WorkflowAppRunner(WorkflowBasedAppRunner): """ Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id + + def run(self) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: @@ -53,80 +68,64 @@ class WorkflowAppRunner: if not workflow: raise ValueError('Workflow not initialized') - inputs = application_generate_entity.inputs - files = application_generate_entity.files - db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - + workflow_callbacks: list[WorkflowCallback] = [] if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) - # Create a variable pool. - system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: + + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. + system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') - - if not app_record.workflow_id: - raise ValueError('Workflow not initialized') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') - - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] - - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks + generator = workflow_entry.run( + callbacks=workflow_callbacks ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index de8542d7b9..00b3b9f57e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -15,10 +16,12 @@ from core.app.entities.queue_entities import ( QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, - QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, @@ -32,19 +35,16 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, - TextReplaceStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowStreamGenerateNodes, + WorkflowStartStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -52,8 +52,8 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, + WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa SystemVariableKey.USER_ID: user_id } - self._task_state = WorkflowTaskState( - iteration_nested_node_ids=[] - ) - self._stream_generate_nodes = self._get_stream_generate_nodes() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) + self._task_state = WorkflowTaskState() def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status, - outputs=workflow_run.outputs_dict, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at) ) ) @@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa To stream response. :return: """ + workflow_run_id = None for stream_response in generator: + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id + yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run_id, + workflow_run_id=workflow_run_id, stream_response=stream_response ) @@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: - self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() - - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - yield self._workflow_node_finish_to_stream_response( + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) + + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, trace_manager=trace_manager + response = self._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, ) # save workflow app log @@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response(delta_text) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._text_replace_to_stream_response(event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() else: continue - if publisher: - publisher.publish(None) + if tts_publisher: + tts_publisher.publish(None) def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: @@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # not save log for debugging return - workflow_app_log = WorkflowAppLog( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, - created_from=created_from.value, - created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), - created_by=self._user.id, - ) + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by = self._user.id + db.session.add(workflow_app_log) db.session.commit() db.session.close() @@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa ) return response - - def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: - """ - Text replace to stream response. - :param text: text - :return: - """ - return TextReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - text=TextReplaceStreamResponse.Data(text=text) - ) - - def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: - """ - Get stream generate nodes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - end_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.END.value - ] - - # parse stream output node value selectors of end nodes - stream_generate_routes = {} - for node_config in end_node_configs: - # get generate route for stream output - end_node_id = node_config['id'] - generate_nodes = EndNode.extract_generate_nodes(graph, node_config) - start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( - end_node_id=end_node_id, - stream_node_ids=generate_nodes - ) - - return stream_generate_routes - - def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get end start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids - - for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): - if node_id not in stream_node_ids: - continue - - node_execution_info = self._task_state.ran_node_execution_infos[node_id] - - # get chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() - - if not route_chunk_node_execution: - continue - - outputs = route_chunk_node_execution.outputs_dict - - if not outputs: - continue - - # get value from outputs - text = outputs.get('text') - - if text: - self._task_state.answer += text - yield self._text_chunk_to_stream_response(text) - - db.session.close() - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return False - - if 'node_id' not in event.metadata: - return False - - node_id = event.metadata.get('node_id') - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - return True - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py deleted file mode 100644 index 4472a7e9b5..0000000000 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager.publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager.publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - pass diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py new file mode 100644 index 0000000000..1709726887 --- /dev/null +++ b/api/core/app/apps/workflow_app_runner.py @@ -0,0 +1,379 @@ +from collections.abc import Mapping +from typing import Any, Optional, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowBasedAppRunner(AppRunner): + def __init__(self, queue_manager: AppQueueManager): + self.queue_manager = queue_manager + + def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: + """ + Init graph + """ + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + # init graph + graph = Graph.init( + graph_config=graph_config + ) + + if not graph: + raise ValueError('graph not found in workflow') + + return graph + + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + # fetch workflow graph + graph_config = workflow.graph_dict + if not graph_config: + raise ValueError('workflow graph not found') + + graph_config = cast(dict[str, Any], graph_config) + + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + + # filter nodes only in iteration + node_configs = [ + node for node in graph_config.get('nodes', []) + if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + ] + + graph_config['nodes'] = node_configs + + node_ids = [node.get('id') for node in node_configs] + + # filter edges only in iteration + edge_configs = [ + edge for edge in graph_config.get('edges', []) + if (edge.get('source') is None or edge.get('source') in node_ids) + and (edge.get('target') is None or edge.get('target') in node_ids) + ] + + graph_config['edges'] = edge_configs + + # init graph + graph = Graph.init( + graph_config=graph_config, + root_node_id=node_id + ) + + if not graph: + raise ValueError('graph not found in workflow') + + # fetch node config from node id + iteration_node_config = None + for node in node_configs: + if node.get('id') == node_id: + iteration_node_config = node + break + + if not iteration_node_config: + raise ValueError('iteration node id not found in workflow graph') + + # Get node class + node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=iteration_node_config + ) + except NotImplementedError: + variable_mapping = {} + + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get('data', {})) + ) + + return graph, variable_pool + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent( + graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state + ) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event( + QueueWorkflowSucceededEvent(outputs=event.outputs) + ) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event( + QueueWorkflowFailedEvent(error=event.error) + ) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result else {}, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result + and event.route_node_state.node_run_result.error + else "Unknown error", + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content, + from_variable_selector=event.from_variable_selector, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunSucceededEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + error=event.error + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish( + event, + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 2e6431d6d0..4e8f3644b1 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -1,10 +1,24 @@ from typing import Optional -from core.app.entities.queue_entities import AppQueueEvent from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: self.current_node_id = None - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self.print_text("\n[on_workflow_run_started]", color='pink') + def on_event( + self, + event: GraphEngineEvent + ) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[GraphRunStartedEvent]", color='pink') + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[GraphRunSucceededEvent]", color='green') + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started( + event=event + ) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded( + event=event + ) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed( + event=event + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk( + event=event + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self.on_workflow_parallel_started( + event=event + ) + elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): + self.on_workflow_parallel_completed( + event=event + ) + elif isinstance(event, IterationRunStartedEvent): + self.on_workflow_iteration_started( + event=event + ) + elif isinstance(event, IterationRunNextEvent): + self.on_workflow_iteration_next( + event=event + ) + elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): + self.on_workflow_iteration_completed( + event=event + ) + else: + self.print_text(f"\n[{event.__class__.__name__}]", color='blue') - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self.print_text("\n[on_workflow_run_succeeded]", color='green') - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self.print_text("\n[on_workflow_run_failed]", color='red') - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: + def on_workflow_node_execute_started( + self, + event: NodeRunStartedEvent + ) -> None: """ Workflow node execute started """ - self.print_text("\n[on_workflow_node_execute_started]", color='yellow') - self.print_text(f"Node ID: {node_id}", color='yellow') - self.print_text(f"Type: {node_type.value}", color='yellow') - self.print_text(f"Index: {node_run_index}", color='yellow') - if predecessor_node_id: - self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') + self.print_text("\n[NodeRunStartedEvent]", color='yellow') + self.print_text(f"Node ID: {event.node_id}", color='yellow') + self.print_text(f"Node Title: {event.node_data.title}", color='yellow') + self.print_text(f"Type: {event.node_type.value}", color='yellow') - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: + def on_workflow_node_execute_succeeded( + self, + event: NodeRunSucceededEvent + ) -> None: """ Workflow node execute succeeded """ - self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') - self.print_text(f"Node ID: {node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') - self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", - color='green') + route_node_state = event.route_node_state - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: + self.print_text("\n[NodeRunSucceededEvent]", color='green') + self.print_text(f"Node ID: {event.node_id}", color='green') + self.print_text(f"Node Title: {event.node_data.title}", color='green') + self.print_text(f"Type: {event.node_type.value}", color='green') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color='green') + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color='green') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color='green') + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", + color='green') + + def on_workflow_node_execute_failed( + self, + event: NodeRunFailedEvent + ) -> None: """ Workflow node execute failed """ - self.print_text("\n[on_workflow_node_execute_failed]", color='red') - self.print_text(f"Node ID: {node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') - self.print_text(f"Error: {error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') + route_node_state = event.route_node_state - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + self.print_text("\n[NodeRunFailedEvent]", color='red') + self.print_text(f"Node ID: {event.node_id}", color='red') + self.print_text(f"Node Title: {event.node_data.title}", color='red') + self.print_text(f"Type: {event.node_type.value}", color='red') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color='red') + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color='red') + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color='red') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color='red') + + def on_node_text_chunk( + self, + event: NodeRunStreamChunkEvent + ) -> None: """ Publish text chunk """ - if not self.current_node_id or self.current_node_id != node_id: - self.current_node_id = node_id - self.print_text('\n[on_node_text_chunk]') - self.print_text(f"Node ID: {node_id}") - self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id + self.print_text('\n[NodeRunStreamChunkEvent]') + self.print_text(f"Node ID: {route_node_state.node_id}") - self.print_text(text, color="pink", end="") + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: + self.print_text(event.chunk_content, color="pink", end="") + + def on_workflow_parallel_started( + self, + event: ParallelBranchRunStartedEvent + ) -> None: + """ + Publish parallel started + """ + self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') + self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') + + def on_workflow_parallel_completed( + self, + event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + ) -> None: + """ + Publish parallel completed + """ + if isinstance(event, ParallelBranchRunSucceededEvent): + color = 'blue' + elif isinstance(event, ParallelBranchRunFailedEvent): + color = 'red' + + self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) + self.print_text(f"Parallel ID: {event.parallel_id}", color=color) + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) + + if isinstance(event, ParallelBranchRunFailedEvent): + self.print_text(f"Error: {event.error}", color=color) + + def on_workflow_iteration_started( + self, + event: IterationRunStartedEvent + ) -> None: """ Publish iteration started """ - self.print_text("\n[on_workflow_iteration_started]", color='blue') - self.print_text(f"Node ID: {node_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color='blue') + self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[dict]) -> None: + def on_workflow_iteration_next( + self, + event: IterationRunNextEvent + ) -> None: """ Publish iteration next """ - self.print_text("\n[on_workflow_iteration_next]", color='blue') + self.print_text("\n[IterationRunNextEvent]", color='blue') + self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') + self.print_text(f"Iteration Index: {event.index}", color='blue') - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: + def on_workflow_iteration_completed( + self, + event: IterationRunSucceededEvent | IterationRunFailedEvent + ) -> None: """ Publish iteration completed """ - self.print_text("\n[on_workflow_iteration_completed]", color='blue') - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self.print_text("\n[on_workflow_event]", color='blue') - self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') + self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') + self.print_text(f"Node ID: {event.iteration_id}", color='blue') def print_text( self, text: str, color: Optional[str] = None, end: str = "\n" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15348251f2..4c86b7eee1 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,3 +1,4 @@ +from datetime import datetime from enum import Enum from typing import Any, Optional @@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class QueueEvent(str, Enum): @@ -31,6 +33,9 @@ class QueueEvent(str, Enum): ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" MESSAGE_FILE = "message_file" + PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" + PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" + PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" ERROR = "error" PING = "ping" STOP = "stop" @@ -38,7 +43,7 @@ class QueueEvent(str, Enum): class AppQueueEvent(BaseModel): """ - QueueEvent entity + QueueEvent abstract entity """ event: QueueEvent @@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel): class QueueLLMChunkEvent(AppQueueEvent): """ QueueLLMChunkEvent entity + Only for basic mode apps """ event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk @@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent): QueueIterationStartEvent entity """ event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: dict = None + inputs: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict] = None + metadata: Optional[dict[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): """ @@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent): event: QueueEvent = QueueEvent.ITERATION_NEXT index: int + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ - event:QueueEvent = QueueEvent.ITERATION_COMPLETED + event: QueueEvent = QueueEvent.ITERATION_COMPLETED + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - outputs: dict + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + error: Optional[str] = None + class QueueTextChunkEvent(AppQueueEvent): """ @@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - metadata: Optional[dict] = None + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAgentMessageEvent(AppQueueEvent): @@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAnnotationReplyEvent(AppQueueEvent): @@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): QueueWorkflowStartedEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): QueueWorkflowSucceededEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_STARTED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 predecessor_node_id: Optional[str] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime class QueueNodeSucceededEvent(AppQueueEvent): @@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_SUCCEEDED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None - execution_metadata: Optional[dict] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None @@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_FAILED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime - inputs: Optional[dict] = None - outputs: Optional[dict] = None - process_data: Optional[dict] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None error: str @@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent): event: QueueEvent = QueueEvent.STOP stopped_by: StopBy + def get_stop_reason(self) -> str: + """ + To stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', + QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', + QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', + QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + } + + return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + class QueueMessage(BaseModel): """ - QueueMessage entity + QueueMessage abstract entity """ task_id: str app_mode: str @@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage): WorkflowQueueMessage entity """ pass + + +class QueueParallelBranchRunStartedEvent(AppQueueEvent): + """ + QueueParallelBranchRunStartedEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunSucceededEvent(AppQueueEvent): + """ + QueueParallelBranchRunSucceededEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunFailedEvent(AppQueueEvent): + """ + QueueParallelBranchRunFailedEvent entity + """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7bc5598984..7cab6ca4e0 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -3,40 +3,11 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.entities import GenerateRouteChunk from models.workflow import WorkflowNodeExecutionStatus -class WorkflowStreamGenerateNodes(BaseModel): - """ - WorkflowStreamGenerateNodes entity - """ - end_node_id: str - stream_node_ids: list[str] - - -class ChatflowStreamGenerateRoute(BaseModel): - """ - ChatflowStreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - -class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class TaskState(BaseModel): """ TaskState entity @@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState): """ answer: str = "" - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None - - iteration_nested_node_ids: list[str] = None - - -class AdvancedChatTaskState(WorkflowTaskState): - """ - AdvancedChatTaskState entity - """ - usage: LLMUsage - - current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None - class StreamEvent(Enum): """ @@ -97,6 +47,8 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + PARALLEL_BRANCH_STARTED = "parallel_branch_started" + PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse): inputs: Optional[dict] = None created_at: int extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse): "predecessor_node_id": self.data.predecessor_node_id, "inputs": None, "created_at": self.data.created_at, - "extras": {} + "extras": {}, + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, } } @@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse): created_at: int finished_at: int files: Optional[list[dict]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse): "execution_metadata": None, "created_at": self.data.created_at, "finished_at": self.data.finished_at, - "files": [] + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, } } + + +class ParallelBranchStartStreamResponse(StreamResponse): + """ + ParallelBranchStartStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED + workflow_run_id: str + data: Data + + +class ParallelBranchFinishedStreamResponse(StreamResponse): + """ + ParallelBranchFinishedStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + status: str + error: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED + workflow_run_id: str + data: Data class IterationNodeStartStreamResponse(StreamResponse): @@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse): extras: dict = {} metadata: dict = {} inputs: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse): created_at: int pre_iteration_output: Optional[Any] = None extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): title: str outputs: Optional[dict] = None created_at: int - extras: dict = None - inputs: dict = None + extras: Optional[dict] = None + inputs: Optional[dict] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float @@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None finished_at: int steps: int + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ - workflow_run_id: str + workflow_run_id: Optional[str] = None class AppBlockingResponse(BaseModel): @@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data - - -class WorkflowIterationState(BaseModel): - """ - WorkflowIterationState entity - """ - - class Data(BaseModel): - """ - Data entity - """ - parent_iteration_id: Optional[str] = None - iteration_id: str - current_index: int - iteration_steps_boundary: list[int] = None - node_execution_id: str - started_at: float - inputs: dict = None - total_tokens: int = 0 - node_data: BaseNodeData - - current_iterations: dict[str, Data] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a3c1fb5824..2f74a180d1 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline: err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) if message: - message = db.session.query(Message).filter(Message.id == message.id).first() - err_desc = self._error_to_desc(err) - message.status = 'error' - message.error = err_desc + refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.commit() + if refetch_message: + err_desc = self._error_to_desc(err) + refetch_message.status = 'error' + refetch_message.error = err_desc + + db.session.commit() return err - def _error_to_desc(cls, e: Exception) -> str: + def _error_to_desc(self, e: Exception) -> str: """ Error to desc. :param e: exception diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 76c50809cf..8ff50dd174 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, @@ -16,11 +15,11 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager @@ -36,7 +35,7 @@ class MessageCycleManage: AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] - _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + _task_state: Union[EasyUITaskState, WorkflowTaskState] def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: """ @@ -45,6 +44,9 @@ class MessageCycleManage: :param query: query :return: thread """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) @@ -52,7 +54,7 @@ class MessageCycleManage: if auto_generate_conversation_name and is_first_message: # start generate thread thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'conversation_id': conversation.id, 'query': query }) @@ -75,6 +77,9 @@ class MessageCycleManage: .first() ) + if not conversation: + return + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: @@ -121,34 +126,13 @@ class MessageCycleManage: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata['retriever_resources'] = event.retriever_resources - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. :param event: event :return: """ - message_file: MessageFile = ( + message_file = ( db.session.query(MessageFile) .filter(MessageFile.id == event.message_file_id) .first() diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4935c43ac4..ed3225310a 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,33 +1,41 @@ import json import time from datetime import datetime, timezone -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, ) from core.app.entities.task_entities import ( - NodeExecutionInfo, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, + WorkflowTaskState, ) -from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -41,54 +49,56 @@ from models.workflow import ( WorkflowRunStatus, WorkflowRunTriggeredFrom, ) -from services.workflow_service import WorkflowService -class WorkflowCycleManage(WorkflowIterationCycleManage): - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _workflow_system_variables: dict[SystemVariableKey, Any] + + def _handle_workflow_run_start(self) -> WorkflowRun: + max_sequence = ( + db.session.query(db.func.max(WorkflowRun.sequence_number)) + .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) + .filter(WorkflowRun.app_id == self._workflow.app_id) + .scalar() + or 0 + ) new_sequence_number = max_sequence + 1 - inputs = {**user_inputs} - for key, value in (system_inputs or {}).items(): + inputs = {**self._application_generate_entity.inputs} + for key, value in (self._workflow_system_variables or {}).items(): if key.value == 'conversation': continue inputs[f'sys.{key.value}'] = value - inputs = WorkflowEngineManager.handle_special_values(inputs) + + inputs = WorkflowEntry.handle_special_values(inputs) + + triggered_from= ( + WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN + ) # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps(inputs), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id + workflow_run = WorkflowRun() + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING.value + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value ) + workflow_run.created_by = self._user.id db.session.add(workflow_run) db.session.commit() @@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_run - def _workflow_run_success( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_success( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, outputs: Optional[str] = None, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run success :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param outputs: outputs :param conversation_id: conversation id :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.outputs = outputs - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) - db.session.close() if trace_manager: trace_manager.add_trace_task( @@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): ) ) + db.session.close() + return workflow_run - def _workflow_run_failed( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_failed( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, error: str, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run failed :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param status: status :param error: error message :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = status.value workflow_run.error = error - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() + + running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value + ).all() + + for workflow_node_execution in running_workflow_node_executions: + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + db.session.commit() + db.session.refresh(workflow_run) db.session.close() @@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_run - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ + def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.add(workflow_node_execution) db.session.commit() @@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: """ Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata + :param event: queue node succeeded event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None - ) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: """ Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message + :param event: queue node failed event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.error = event.error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_start_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + ################################################# + # to stream responses # + ################################################# + + def _workflow_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, - created_at=int(workflow_run.created_at.timestamp()) - ) + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + ), ) - def _workflow_finish_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. :param task_id: task id @@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + 'id': created_by_account.id, + 'name': created_by_account.name, + 'email': created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, + 'id': created_by_end_user.id, + 'user': created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + ), ) - def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution) \ - -> NodeStartStreamResponse: + def _workflow_node_start_to_stream_response( + self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeStartStreamResponse]: """ Workflow node start to stream response. :param event: queue node started event @@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()) - ) + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) # extras logic @@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): response.data.extras['icon'] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, - provider_id=node_data.provider_id + provider_id=node_data.provider_id, ) return response - def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + def _workflow_node_finish_to_stream_response( + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. + :param event: queue node succeeded or failed event :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), + ) + + def _workflow_parallel_branch_start_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: + """ + Workflow parallel branch start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run started event + :return: + """ + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + created_at=int(time.time()), + ) + ) + + def _workflow_parallel_branch_finished_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent + ) -> ParallelBranchFinishedStreamResponse: + """ + Workflow parallel branch finished to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run succeeded or failed event + :return: + """ + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchFinishedStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), ) ) - def _handle_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs=self._workflow_system_variables + def _workflow_iteration_start_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + """ + Workflow iteration start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration start event + :return: + """ + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) ) - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id + def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + """ + Workflow iteration next to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration next event + :return: + """ + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) ) - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - - execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None - - if self._iteration_state and self._iteration_state.current_iterations: - if not execution_metadata: - execution_metadata = {} - current_iteration_data = None - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if data.parent_iteration_id == None: - current_iteration_data = data - break - - if current_iteration_data: - execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id - execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index - - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, + def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + """ + Workflow iteration completed to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration completed event + :return: + """ + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, outputs=event.outputs, - execution_metadata=execution_metadata + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), + total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) - - if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if self._iteration_state: - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=execution_metadata - ) - - db.session.close() - - return workflow_node_execution - - def _handle_workflow_finished( - self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None - ) -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - - if conversation_id is None: - conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - latest_node_execution_info = self._task_state.latest_node_execution_info - if latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() - if (workflow_node_execution - and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=latest_node_execution_info.start_at, - error='Workflow stopped.' - ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run + ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: """ @@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return value.to_dict() return None + + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Refetch workflow run + :param workflow_run_id: workflow run id + :return: + """ + workflow_run = db.session.query(WorkflowRun).filter( + WorkflowRun.id == workflow_run_id).first() + + if not workflow_run: + raise Exception(f'Workflow run not found: {workflow_run_id}') + + return workflow_run + + def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + """ + Refetch workflow node execution + :param node_execution_id: workflow node execution id + :return: + """ + workflow_node_execution = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, + WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, + WorkflowNodeExecution.workflow_id == self._workflow.id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.node_execution_id == node_execution_id, + ) + .first() + ) + + if not workflow_node_execution: + raise Exception(f'Workflow node execution not found: {node_execution_id}') + + return workflow_node_execution \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index bd98c82720..e69de29bb2 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -1,16 +0,0 @@ -from typing import Any, Union - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariableKey -from models.account import Account -from models.model import EndUser -from models.workflow import Workflow - - -class WorkflowCycleStateManager: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py deleted file mode 100644 index aff1870714..0000000000 --- a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import time -from collections.abc import Generator -from datetime import datetime, timezone -from typing import Optional, Union - -from core.app.entities.queue_entities import ( - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, -) -from core.app.entities.task_entities import ( - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - NodeExecutionInfo, - WorkflowIterationState, -) -from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.workflow_engine_manager import WorkflowEngineManager -from extensions.ext_database import db -from models.workflow import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, -) - - -class WorkflowIterationCycleManage(WorkflowCycleStateManager): - _iteration_state: WorkflowIterationState = None - - def _init_iteration_state(self) -> WorkflowIterationState: - if not self._iteration_state: - self._iteration_state = WorkflowIterationState( - current_iterations={} - ) - - def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ - -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: - """ - Handle iteration to stream response - :param task_id: task id - :param event: iteration event - :return: - """ - if isinstance(event, QueueIterationStartEvent): - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs, - metadata=event.metadata - ) - ) - elif isinstance(event, QueueIterationNextEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={} - ) - ) - elif isinstance(event, QueueIterationCompletedEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - error=None, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) - - def _init_iteration_execution_from_workflow_run(self, - workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - inputs=json.dumps(inputs) if inputs else None, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - execution_metadata=json.dumps({ - 'started_run_index': node_run_index + 1, - 'current_index': 0, - 'steps_boundary': [], - }), - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: - if isinstance(event, QueueIterationStartEvent): - return self._handle_iteration_started(event) - elif isinstance(event, QueueIterationNextEvent): - return self._handle_iteration_next(event) - elif isinstance(event, QueueIterationCompletedEvent): - return self._handle_iteration_completed(event) - - def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: - self._init_iteration_state() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_iteration_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=NodeType.ITERATION, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( - parent_iteration_id=None, - iteration_id=event.node_id, - current_index=0, - iteration_steps_boundary=[], - node_execution_id=workflow_node_execution.id, - started_at=time.perf_counter(), - inputs=event.inputs, - total_tokens=0, - node_data=event.node_data - ) - - db.session.close() - - return workflow_node_execution - - def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: - if event.node_id not in self._iteration_state.current_iterations: - return - current_iteration = self._iteration_state.current_iterations[event.node_id] - current_iteration.current_index = event.index - current_iteration.iteration_steps_boundary.append(event.node_run_index) - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['current_index'] = event.index - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - db.session.close() - - def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): - if event.node_id not in self._iteration_state.current_iterations: - return - - current_iteration = self._iteration_state.current_iterations[event.node_id] - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - # remove current iteration - self._iteration_state.current_iterations.pop(event.node_id, None) - - # set latest node execution info - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.latest_node_execution_info = latest_node_execution_info - - db.session.close() - - def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: - """ - Handle iteration exception - """ - if not self._iteration_state or not self._iteration_state.current_iterations: - return - - for node_id, current_iteration in self._iteration_state.current_iterations.items(): - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - db.session.commit() - db.session.close() - - yield IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=node_id, - node_id=node_id, - node_type=NodeType.ITERATION.value, - title=current_iteration.node_data.title, - outputs={}, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a..59a4c103a2 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -63,6 +63,39 @@ class LLMUsage(ModelUsage): latency=0.0 ) + def plus(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Add two LLMUsage instances together. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + if self.total_tokens == 0: + return other + else: + return LLMUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + prompt_unit_price=other.prompt_unit_price, + prompt_price_unit=other.prompt_price_unit, + prompt_price=self.prompt_price + other.prompt_price, + completion_tokens=self.completion_tokens + other.completion_tokens, + completion_unit_price=other.completion_unit_price, + completion_price_unit=other.completion_price_unit, + completion_price=self.completion_price + other.completion_price, + total_tokens=self.total_tokens + other.total_tokens, + total_price=self.total_price + other.total_price, + currency=other.currency, + latency=self.latency + other.latency + ) + + def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Overload the + operator to add two LLMUsage instances. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + return self.plus(other) class LLMResult(BaseModel): """ diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 9a4d8db4e2..69e28770c3 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -34,13 +34,13 @@ class OutputModeration(BaseModel): final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def should_direct_output(self): + def should_direct_output(self) -> bool: return self.final_output is not None - def get_final_output(self): - return self.final_output + def get_final_output(self) -> str: + return self.final_output or "" - def append_new_token(self, token: str): + def append_new_token(self, token: str) -> None: self.buffer += token if not self.thread: diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 12e498e76d..15e915628e 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,7 +1,7 @@ import json import logging from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.file.file_obj import FileTransferMethod, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType @@ -18,6 +18,7 @@ class WorkflowTool(Tool): version: str workflow_entities: dict[str, Any] workflow_call_depth: int + thread_pool_id: Optional[str] = None label: str @@ -57,6 +58,7 @@ class WorkflowTool(Tool): invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, + workflow_thread_pool_id=self.thread_pool_id ) data = result.get('data', {}) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0e15151aa4..6c0e906628 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -128,6 +128,7 @@ class ToolEngine: user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, + thread_pool_id: Optional[str] = None ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. @@ -141,6 +142,7 @@ class ToolEngine: if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4a0188af49..4778d79ed9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -25,7 +25,6 @@ from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -249,7 +248,7 @@ class ToolManager: return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the workflow tool runtime """ diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 564b9d3e14..23e7c0c243 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 6db8adf4c2..9015eea85c 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,116 +1,15 @@ from abc import ABC, abstractmethod -from typing import Any, Optional -from core.app.entities.queue_entities import AppQueueEvent -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self) -> None: + def on_event( + self, + event: GraphEngineEvent + ) -> None: """ - Workflow run started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - raise NotImplementedError - - @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any], - ) -> None: - """ - Publish iteration next - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - raise NotImplementedError - - @abstractmethod - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event + Published event """ raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index 6bf0c11c7d..e7e6710cbd 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel): desc: Optional[str] = None class BaseIterationNodeData(BaseNodeData): - start_node_id: str + start_node_id: Optional[str] = None class BaseIterationState(BaseModel): iteration_node_id: str diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 025453567b..5e2a5cb466 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMUsage from models import WorkflowNodeExecutionStatus @@ -28,6 +28,7 @@ class NodeType(Enum): VARIABLE_ASSIGNER = 'variable-assigner' LOOP = 'loop' ITERATION = 'iteration' + ITERATION_START = 'iteration-start' # fake start node for iteration PARAMETER_EXTRACTOR = 'parameter-extractor' CONVERSATION_VARIABLE_ASSIGNER = 'assigner' @@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum): TOOL_INFO = 'tool_info' ITERATION_ID = 'iteration_id' ITERATION_INDEX = 'iteration_index' + PARALLEL_ID = 'parallel_id' + PARALLEL_START_NODE_ID = 'parallel_start_node_id' + PARENT_PARALLEL_ID = 'parent_parallel_id' + PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' class NodeRunResult(BaseModel): @@ -65,11 +70,32 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs + inputs: Optional[dict[str, Any]] = None # node inputs + process_data: Optional[dict[str, Any]] = None # process data + outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + + +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 27d0b672f6..48a20d25ae 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -2,6 +2,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Union +from pydantic import BaseModel, Field, model_validator from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory @@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" -class VariablePool: - def __init__( - self, - system_variables: Mapping[SystemVariableKey, Any], - user_inputs: Mapping[str, Any], - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable] | None = None, - ) -> None: - # system variables - # for example: - # { - # 'query': 'abc', - # 'files': [] - # } +class VariablePool(BaseModel): + # Variable dictionary is a dictionary for looking up variables by their selector. + # The first element of the selector is the node id, it's the first-level key in the dictionary. + # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the + # elements of the selector except the first one. + variable_dictionary: dict[str, dict[int, Segment]] = Field( + description='Variables mapping', + default=defaultdict(dict) + ) - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) + # TODO: This user inputs is not used for pool. + user_inputs: Mapping[str, Any] = Field( + description='User inputs', + ) - # TODO: This user inputs is not used for pool. - self.user_inputs = user_inputs + system_variables: Mapping[SystemVariableKey, Any] = Field( + description='System variables', + ) + environment_variables: Sequence[Variable] = Field( + description="Environment variables.", + default_factory=list + ) + + conversation_variables: Sequence[Variable] | None = None + + @model_validator(mode="after") + def val_model_after(self): + """ + Append system variables + :return: + """ # Add system variables to the variable pool - self.system_variables = system_variables - for key, value in system_variables.items(): + for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables: + for var in self.environment_variables or []: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) # Add conversation variables to the variable pool - for var in conversation_variables or []: + for var in self.conversation_variables or []: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + return self + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. @@ -79,7 +89,7 @@ class VariablePool: v = factory.build_segment(value) hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]][hash_key] = v + self.variable_dictionary[selector[0]][hash_key] = v def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -97,7 +107,7 @@ class VariablePool: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value @@ -118,7 +128,7 @@ class VariablePool: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None def remove(self, selector: Sequence[str], /): @@ -134,7 +144,19 @@ class VariablePool: if not selector: return if len(selector) == 1: - self._variable_dictionary[selector[0]] = {} + self.variable_dictionary[selector[0]] = {} return hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[selector[0]].pop(hash_key, None) + + def remove_node(self, node_id: str, /): + """ + Remove all variables associated with a given node id. + + Args: + node_id (str): The node id to remove. + + Returns: + None + """ + self.variable_dictionary.pop(node_id, None) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 9b35b8df8a..4bf4e454bb 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -66,8 +66,7 @@ class WorkflowRunState: self.variable_pool = variable_pool self.total_tokens = 0 - self.workflow_nodes_and_results = [] - self.current_iteration_state = None self.workflow_node_steps = 1 - self.workflow_node_runs = [] \ No newline at end of file + self.workflow_node_runs = [] + self.current_iteration_state = None diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index fe79fadf66..07cbcd981e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,10 +1,8 @@ -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): - self.node_id = node_id - self.node_type = node_type - self.node_title = node_title + def __init__(self, node_instance: BaseNode, error: str): + self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_title} run failed: {error}") + super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/__init__.py b/api/core/workflow/graph_engine/condition_handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py new file mode 100644 index 0000000000..4099def4e2 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod + +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class RunConditionHandler(ABC): + def __init__(self, + init_params: GraphInitParams, + graph: Graph, + condition: RunCondition): + self.init_params = init_params + self.graph = graph + self.condition = condition + + @abstractmethod + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState + ) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py new file mode 100644 index 0000000000..705eb908b1 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -0,0 +1,28 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class BranchIdentifyRunConditionHandler(RunConditionHandler): + + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.branch_identify: + raise Exception("Branch identify is required") + + run_result = previous_route_node_state.node_run_result + if not run_result: + return False + + if not run_result.edge_source_handle: + return False + + return self.condition.branch_identify == run_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py new file mode 100644 index 0000000000..1edaf92da7 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -0,0 +1,32 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.utils.condition.processor import ConditionProcessor + + +class ConditionRunConditionHandlerHandler(RunConditionHandler): + def check(self, + graph_runtime_state: GraphRuntimeState, + previous_route_node_state: RouteNodeState + ) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.conditions: + return True + + # process condition + condition_processor = ConditionProcessor() + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, + conditions=self.condition.conditions + ) + + # Apply the logical operator for the current case + compare_result = all(group_result) + + return compare_result diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py new file mode 100644 index 0000000000..2eb2e58bfc --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -0,0 +1,35 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class ConditionManager: + @staticmethod + def get_condition_handler( + init_params: GraphInitParams, + graph: Graph, + run_condition: RunCondition + ) -> RunConditionHandler: + """ + Get condition handler + + :param init_params: init params + :param graph: graph + :param run_condition: run condition + :return: condition handler + """ + if run_condition.type == "branch_identify": + return BranchIdentifyRunConditionHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) + else: + return ConditionRunConditionHandlerHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py new file mode 100644 index 0000000000..06dc4cb8f4 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/event.py @@ -0,0 +1,163 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class GraphEngineEvent(BaseModel): + pass + + +########################################### +# Graph Events +########################################### + + +class BaseGraphEvent(GraphEngineEvent): + pass + + +class GraphRunStartedEvent(BaseGraphEvent): + pass + + +class GraphRunSucceededEvent(BaseGraphEvent): + outputs: Optional[dict[str, Any]] = None + """outputs""" + + +class GraphRunFailedEvent(BaseGraphEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Node Events +########################################### + + +class BaseNodeEvent(GraphEngineEvent): + id: str = Field(..., description="node execution id") + node_id: str = Field(..., description="node id") + node_type: NodeType = Field(..., description="node type") + node_data: BaseNodeData = Field(..., description="node data") + route_node_state: RouteNodeState = Field(..., description="route node state") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class NodeRunStartedEvent(BaseNodeEvent): + predecessor_node_id: Optional[str] = None + """predecessor node id""" + + +class NodeRunStreamChunkEvent(BaseNodeEvent): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + + +class NodeRunRetrieverResourceEvent(BaseNodeEvent): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +class NodeRunSucceededEvent(BaseNodeEvent): + pass + + +class NodeRunFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + +########################################### +# Parallel Branch Events +########################################### + + +class BaseParallelBranchEvent(GraphEngineEvent): + parallel_id: str = Field(..., description="parallel id") + """parallel id""" + parallel_start_node_id: str = Field(..., description="parallel start node id") + """parallel start node id""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Iteration Events +########################################### + + +class BaseIterationEvent(GraphEngineEvent): + iteration_id: str = Field(..., description="iteration node execution id") + iteration_node_id: str = Field(..., description="iteration node id") + iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") + iteration_node_data: BaseNodeData = Field(..., description="node data") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + + +class IterationRunStartedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + predecessor_node_id: Optional[str] = None + + +class IterationRunNextEvent(BaseIterationEvent): + index: int = Field(..., description="index") + pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + + +class IterationRunSucceededEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + +class IterationRunFailedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + error: str = Field(..., description="failed reason") + + +InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py new file mode 100644 index 0000000000..49007b870d --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -0,0 +1,692 @@ +import uuid +from collections.abc import Mapping +from typing import Any, Optional, cast + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter +from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute +from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter +from core.workflow.nodes.end.entities import EndStreamParam + + +class GraphEdge(BaseModel): + source_node_id: str = Field(..., description="source node id") + target_node_id: str = Field(..., description="target node id") + run_condition: Optional[RunCondition] = None + """run condition""" + + +class GraphParallel(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") + start_from_node_id: str = Field(..., description="start from node id") + parent_parallel_id: Optional[str] = None + """parent parallel id""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id""" + end_to_node_id: Optional[str] = None + """end to node id""" + + +class Graph(BaseModel): + root_node_id: str = Field(..., description="root node id of the graph") + node_ids: list[str] = Field(default_factory=list, description="graph node ids") + node_id_config_mapping: dict[str, dict] = Field( + default_factory=list, + description="node configs mapping (node id: node config)" + ) + edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, + description="graph edge mapping (source node id: edges)" + ) + reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, + description="reverse graph edge mapping (target node id: edges)" + ) + parallel_mapping: dict[str, GraphParallel] = Field( + default_factory=dict, + description="graph parallel mapping (parallel id: parallel)" + ) + node_parallel_mapping: dict[str, str] = Field( + default_factory=dict, + description="graph node parallel mapping (node id: parallel id)" + ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( + ..., + description="answer stream generate routes" + ) + end_stream_param: EndStreamParam = Field( + ..., + description="end stream param" + ) + + @classmethod + def init(cls, + graph_config: Mapping[str, Any], + root_node_id: Optional[str] = None) -> "Graph": + """ + Init graph + + :param graph_config: graph config + :param root_node_id: root node id + :return: graph + """ + # edge configs + edge_configs = graph_config.get('edges') + if edge_configs is None: + edge_configs = [] + + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + edge_mapping: dict[str, list[GraphEdge]] = {} + reverse_edge_mapping: dict[str, list[GraphEdge]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get('source') + if not source_node_id: + continue + + if source_node_id not in edge_mapping: + edge_mapping[source_node_id] = [] + + target_node_id = edge_config.get('target') + if not target_node_id: + continue + + if target_node_id not in reverse_edge_mapping: + reverse_edge_mapping[target_node_id] = [] + + # is target node id in source node id edge mapping + if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]): + continue + + target_edge_ids.add(target_node_id) + + # parse run condition + run_condition = None + if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': + run_condition = RunCondition( + type='branch_identify', + branch_identify=edge_config.get('sourceHandle') + ) + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=target_node_id, + run_condition=run_condition + ) + + edge_mapping[source_node_id].append(graph_edge) + reverse_edge_mapping[target_node_id].append(graph_edge) + + # node configs + node_configs = graph_config.get('nodes') + if not node_configs: + raise ValueError("Graph must have at least one node") + + node_configs = cast(list, node_configs) + + # fetch nodes that have no predecessor node + root_node_configs = [] + all_node_id_config_mapping: dict[str, dict] = {} + for node_config in node_configs: + node_id = node_config.get('id') + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + all_node_id_config_mapping[node_id] = node_config + + root_node_ids = [node_config.get('id') for node_config in root_node_configs] + + # fetch root node + if not root_node_id: + # if no root node id, use the START type node as root node + root_node_id = next((node_config.get("id") for node_config in root_node_configs + if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + + if not root_node_id or root_node_id not in root_node_ids: + raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # Check whether it is connected to the previous node + cls._check_connected_to_previous_node( + route=[root_node_id], + edge_mapping=edge_mapping + ) + + # fetch all node ids from root node + node_ids = [root_node_id] + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=root_node_id + ) + + node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} + + # init parallel mapping + parallel_mapping: dict[str, GraphParallel] = {} + node_parallel_mapping: dict[str, str] = {} + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=root_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping + ) + + # Check if it exceeds N layers of parallel + for parallel in parallel_mapping.values(): + if parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=3, + parent_parallel_id=parallel.parent_parallel_id + ) + + # init answer stream generate routes + answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping + ) + + # init end stream param + end_stream_param = EndStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + node_parallel_mapping=node_parallel_mapping + ) + + # init graph + graph = cls( + root_node_id=root_node_id, + node_ids=node_ids, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + answer_stream_generate_routes=answer_stream_generate_routes, + end_stream_param=end_stream_param + ) + + return graph + + def add_extra_edge(self, source_node_id: str, + target_node_id: str, + run_condition: Optional[RunCondition] = None) -> None: + """ + Add extra edge to the graph + + :param source_node_id: source node id + :param target_node_id: target node id + :param run_condition: run condition + """ + if source_node_id not in self.node_ids or target_node_id not in self.node_ids: + return + + if source_node_id not in self.edge_mapping: + self.edge_mapping[source_node_id] = [] + + if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: + return + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=target_node_id, + run_condition=run_condition + ) + + self.edge_mapping[source_node_id].append(graph_edge) + + def get_leaf_node_ids(self) -> list[str]: + """ + Get leaf node ids of the graph + + :return: leaf node ids + """ + leaf_node_ids = [] + for node_id in self.node_ids: + if node_id not in self.edge_mapping: + leaf_node_ids.append(node_id) + elif (len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + leaf_node_ids.append(node_id) + + return leaf_node_ids + + @classmethod + def _recursively_add_node_ids(cls, + node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + node_id: str) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=graph_edge.target_node_id + ) + + @classmethod + def _check_connected_to_previous_node( + cls, + route: list[str], + edge_mapping: dict[str, list[GraphEdge]] + ) -> None: + """ + Check whether it is connected to the previous node + """ + last_node_id = route[-1] + + for graph_edge in edge_mapping.get(last_node_id, []): + if not graph_edge.target_node_id: + continue + + if graph_edge.target_node_id in route: + raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + + new_route = route[:] + new_route.append(graph_edge.target_node_id) + cls._check_connected_to_previous_node( + route=new_route, + edge_mapping=edge_mapping, + ) + + @classmethod + def _recursively_add_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str], + parent_parallel: Optional[GraphParallel] = None + ) -> None: + """ + Recursively add parallel ids + + :param edge_mapping: edge mapping + :param start_node_id: start from node id + :param parallel_mapping: parallel mapping + :param node_parallel_mapping: node parallel mapping + :param parent_parallel: parent parallel + """ + target_node_edges = edge_mapping.get(start_node_id, []) + parallel = None + if len(target_node_edges) > 1: + # fetch all node ids in current parallels + parallel_branch_node_ids = [] + condition_edge_mappings = {} + for graph_edge in target_node_edges: + if graph_edge.run_condition is None: + parallel_branch_node_ids.append(graph_edge.target_node_id) + else: + condition_hash = graph_edge.run_condition.hash + if not condition_hash in condition_edge_mappings: + condition_edge_mappings[condition_hash] = [] + + condition_edge_mappings[condition_hash].append(graph_edge) + + for _, graph_edges in condition_edge_mappings.items(): + if len(graph_edges) > 1: + for graph_edge in graph_edges: + parallel_branch_node_ids.append(graph_edge.target_node_id) + + # any target node id in node_parallel_mapping + if parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None + ) + parallel_mapping[parallel.id] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=parallel_branch_node_ids + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) + or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + for graph_edge in target_node_edges: + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if ( + parent_parallel_parent_parallel + and ( + not parent_parallel_parent_parallel.end_to_node_id + or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) + ) + ): + current_parallel = parent_parallel_parent_parallel + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel + ) + + @classmethod + def _check_exceed_parallel_limit( + cls, + parallel_mapping: dict[str, GraphParallel], + level_limit: int, + parent_parallel_id: str, + current_level: int = 1 + ) -> None: + """ + Check if it exceeds N layers of parallel + """ + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + return + + current_level += 1 + if current_level > level_limit: + raise ValueError(f"Exceeds {level_limit} layers of parallel") + + if parent_parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=level_limit, + parent_parallel_id=parent_parallel.parent_parallel_id, + current_level=current_level + ) + + @classmethod + def _recursively_add_parallel_node_ids(cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str) -> None: + """ + Recursively add node ids + + :param branch_node_ids: in branch node ids + :param edge_mapping: edge mapping + :param merge_node_id: merge node id + :param start_node_id: start node id + """ + for graph_edge in edge_mapping.get(start_node_id, []): + if (graph_edge.target_node_id != merge_node_id + and graph_edge.target_node_id not in branch_node_ids): + branch_node_ids.append(graph_edge.target_node_id) + cls._recursively_add_parallel_node_ids( + branch_node_ids=branch_node_ids, + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=graph_edge.target_node_id + ) + + @classmethod + def _fetch_all_node_ids_in_parallels(cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: + """ + Fetch all node ids in parallels + """ + routes_node_ids: dict[str, list[str]] = {} + for parallel_branch_node_id in parallel_branch_node_ids: + routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] + + # fetch routes node ids + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=parallel_branch_node_id, + routes_node_ids=routes_node_ids[parallel_branch_node_id] + ) + + # fetch leaf node ids from routes node ids + leaf_node_ids: dict[str, list[str]] = {} + merge_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: + if branch_node_id not in leaf_node_ids: + leaf_node_ids[branch_node_id] = [] + + leaf_node_ids[branch_node_id].append(node_id) + + for branch_node_id2, inner_route2 in routes_node_ids.items(): + if ( + branch_node_id != branch_node_id2 + and node_id in inner_route2 + and len(reverse_edge_mapping.get(node_id, [])) > 1 + and cls._is_node_in_routes( + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=node_id, + routes_node_ids=routes_node_ids + ) + ): + if node_id not in merge_branch_node_ids: + merge_branch_node_ids[node_id] = [] + + if branch_node_id2 not in merge_branch_node_ids[node_id]: + merge_branch_node_ids[node_id].append(branch_node_id2) + + # sorted merge_branch_node_ids by branch_node_ids length desc + merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + + duplicate_end_node_ids = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): + if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): + if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids + + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): + # check which node is after + if cls._is_node2_after_node1( + node1_id=node_id, + node2_id=node_id2, + edge_mapping=edge_mapping + ): + if node_id in merge_branch_node_ids: + del merge_branch_node_ids[node_id2] + elif cls._is_node2_after_node1( + node1_id=node_id2, + node2_id=node_id, + edge_mapping=edge_mapping + ): + if node_id2 in merge_branch_node_ids: + del merge_branch_node_ids[node_id] + + branches_merge_node_ids: dict[str, str] = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + if len(branch_node_ids) <= 1: + continue + + for branch_node_id in branch_node_ids: + if branch_node_id in branches_merge_node_ids: + continue + + branches_merge_node_ids[branch_node_id] = node_id + + in_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + in_branch_node_ids[branch_node_id] = [] + if branch_node_id not in branches_merge_node_ids: + # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].append(branch_node_id) + in_branch_node_ids[branch_node_id].extend(node_ids) + else: + merge_node_id = branches_merge_node_ids[branch_node_id] + if merge_node_id != branch_node_id: + in_branch_node_ids[branch_node_id].append(branch_node_id) + + # fetch all node ids from branch_node_id and merge_node_id + cls._recursively_add_parallel_node_ids( + branch_node_ids=in_branch_node_ids[branch_node_id], + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=branch_node_id + ) + + return in_branch_node_ids + + @classmethod + def _recursively_fetch_routes(cls, + edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: list[str]) -> None: + """ + Recursively fetch route + """ + if start_node_id not in edge_mapping: + return + + for graph_edge in edge_mapping[start_node_id]: + # find next node ids + if graph_edge.target_node_id not in routes_node_ids: + routes_node_ids.append(graph_edge.target_node_id) + + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=graph_edge.target_node_id, + routes_node_ids=routes_node_ids + ) + + @classmethod + def _is_node_in_routes(cls, + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: dict[str, list[str]]) -> bool: + """ + Recursively check if the node is in the routes + """ + if start_node_id not in reverse_edge_mapping: + return False + + all_routes_node_ids = set() + parallel_start_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + all_routes_node_ids.add(node_id) + + if branch_node_id in reverse_edge_mapping: + for graph_edge in reverse_edge_mapping[branch_node_id]: + if graph_edge.source_node_id not in parallel_start_node_ids: + parallel_start_node_ids[graph_edge.source_node_id] = [] + + parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) + + parallel_start_node_id = None + for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + if set(branch_node_ids) == set(routes_node_ids.keys()): + parallel_start_node_id = p_start_node_id + return True + + if not parallel_start_node_id: + raise Exception("Parallel start node id not found") + + for graph_edge in reverse_edge_mapping[start_node_id]: + if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: + return False + + return True + + @classmethod + def _is_node2_after_node1( + cls, + node1_id: str, + node2_id: str, + edge_mapping: dict[str, list[GraphEdge]] + ) -> bool: + """ + is node2 after node1 + """ + if node1_id not in edge_mapping: + return False + + for graph_edge in edge_mapping[node1_id]: + if graph_edge.target_node_id == node2_id: + return True + + if cls._is_node2_after_node1( + node1_id=graph_edge.target_node_id, + node2_id=node2_id, + edge_mapping=edge_mapping + ): + return True + + return False \ No newline at end of file diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py new file mode 100644 index 0000000000..1a403f3e49 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from models.workflow import WorkflowType + + +class GraphInitParams(BaseModel): + # init params + tenant_id: str = Field(..., description="tenant / workspace id") + app_id: str = Field(..., description="app id") + workflow_type: WorkflowType = Field(..., description="workflow type") + workflow_id: str = Field(..., description="workflow id") + graph_config: Mapping[str, Any] = Field(..., description="graph config") + user_id: str = Field(..., description="user id") + user_from: UserFrom = Field(..., description="user from, account or end-user") + invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py new file mode 100644 index 0000000000..c7d484ddf5 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -0,0 +1,27 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState + + +class GraphRuntimeState(BaseModel): + variable_pool: VariablePool = Field(..., description="variable pool") + """variable pool""" + + start_at: float = Field(..., description="start time") + """start time""" + total_tokens: int = 0 + """total tokens""" + llm_usage: LLMUsage = LLMUsage.empty_usage() + """llm usage info""" + outputs: dict[str, Any] = {} + """outputs""" + + node_run_steps: int = 0 + """node run steps""" + + node_run_state: RuntimeRouteState = RuntimeRouteState() + """node run state""" diff --git a/api/core/workflow/graph_engine/entities/next_graph_node.py b/api/core/workflow/graph_engine/entities/next_graph_node.py new file mode 100644 index 0000000000..6aa4341ddf --- /dev/null +++ b/api/core/workflow/graph_engine/entities/next_graph_node.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.graph_engine.entities.graph import GraphParallel + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py new file mode 100644 index 0000000000..0362343568 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -0,0 +1,21 @@ +import hashlib +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify like: sourceHandle, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" + + @property + def hash(self) -> str: + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() \ No newline at end of file diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py new file mode 100644 index 0000000000..b5d6e4c09d --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -0,0 +1,111 @@ +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus + + +class RouteNodeState(BaseModel): + class Status(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """node state id""" + + node_id: str + """node id""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.RUNNING + """node status""" + + start_at: datetime + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + index: int = 1 + + def set_finished(self, run_result: NodeRunResult) -> None: + """ + Node finished + + :param run_result: run result + """ + if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + raise Exception(f"Route state {self.id} already finished") + + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + self.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + self.status = RouteNodeState.Status.FAILED + self.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + self.node_run_result = run_result + self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + +class RuntimeRouteState(BaseModel): + routes: dict[str, list[str]] = Field( + default_factory=dict, + description="graph state routes (source_node_state_id: target_node_state_id)" + ) + + node_state_mapping: dict[str, RouteNodeState] = Field( + default_factory=dict, + description="node state mapping (route_node_state_id: route_node_state)" + ) + + def create_node_state(self, node_id: str) -> RouteNodeState: + """ + Create node state + + :param node_id: node id + """ + state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + self.node_state_mapping[state.id] = state + return state + + def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + """ + Add route to the graph state + + :param source_node_state_id: source node state id + :param target_node_state_id: target node state id + """ + if source_node_state_id not in self.routes: + self.routes[source_node_state_id] = [] + + self.routes[source_node_state_id].append(target_node_state_id) + + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ + -> list[RouteNodeState]: + """ + Get routes with node state by source node id + + :param source_node_state_id: source node state id + :return: routes with node state + """ + return [self.node_state_mapping[target_state_id] + for target_state_id in self.routes.get(source_node_state_id, [])] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py new file mode 100644 index 0000000000..65d9ab8446 --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -0,0 +1,716 @@ +import logging +import queue +import time +import uuid +from collections.abc import Generator, Mapping +from concurrent.futures import ThreadPoolExecutor, wait +from typing import Any, Optional + +from flask import Flask, current_app + +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeType, + UserFrom, +) +from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.event import ( + BaseIterationEvent, + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph, GraphEdge +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.node_mapping import node_classes +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + +logger = logging.getLogger(__name__) + + +class GraphEngineThreadPool(ThreadPoolExecutor): + def __init__(self, max_workers=None, thread_name_prefix='', + initializer=None, initargs=(), max_submit_count=100) -> None: + super().__init__(max_workers, thread_name_prefix, initializer, initargs) + self.max_submit_count = max_submit_count + self.submit_count = 0 + + def submit(self, fn, *args, **kwargs): + self.submit_count += 1 + self.check_is_full() + + return super().submit(fn, *args, **kwargs) + + def check_is_full(self) -> None: + print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") + if self.submit_count > self.max_submit_count: + raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") + + +class GraphEngine: + workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} + + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None + ) -> None: + thread_pool_max_submit_count = 100 + thread_pool_max_workers = 10 + + ## init thread pool + if thread_pool_id: + if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") + + self.thread_pool_id = thread_pool_id + self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] + self.is_main_thread_pool = False + else: + self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool_id = str(uuid.uuid4()) + self.is_main_thread_pool = True + GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool + + self.graph = graph + self.init_params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth + ) + + self.graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter() + ) + + self.max_execution_steps = max_execution_steps + self.max_execution_time = max_execution_time + + def run(self) -> Generator[GraphEngineEvent, None, None]: + # trigger graph run start event + yield GraphRunStartedEvent() + + try: + stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] + if self.init_params.workflow_type == WorkflowType.CHAT: + stream_processor_cls = AnswerStreamProcessor + else: + stream_processor_cls = EndStreamProcessor + + stream_processor = stream_processor_cls( + graph=self.graph, + variable_pool=self.graph_runtime_state.variable_pool + ) + + # run graph + generator = stream_processor.process( + self._run(start_node_id=self.graph.root_node_id) + ) + + for item in generator: + try: + yield item + if isinstance(item, NodeRunFailedEvent): + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') + return + elif isinstance(item, NodeRunSucceededEvent): + if item.node_type == NodeType.END: + self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {}) + elif item.node_type == NodeType.ANSWER: + if "answer" not in self.graph_runtime_state.outputs: + self.graph_runtime_state.outputs["answer"] = "" + + self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "") + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() + except Exception as e: + logger.exception(f"Graph run failed: {str(e)}") + yield GraphRunFailedEvent(error=str(e)) + return + + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + except GraphRunFailedError as e: + yield GraphRunFailedEvent(error=e.error) + return + except Exception as e: + logger.exception("Unknown Error when graph running") + yield GraphRunFailedEvent(error=str(e)) + raise e + finally: + if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] + + def _run( + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None + ) -> Generator[GraphEngineEvent, None, None]: + parallel_start_node_id = None + if in_parallel_id: + parallel_start_node_id = start_node_id + + next_node_id = start_node_id + previous_route_node_state: Optional[RouteNodeState] = None + while True: + # max steps reached + if self.graph_runtime_state.node_run_steps > self.max_execution_steps: + raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + + # or max execution time reached + if self._is_timed_out( + start_at=self.graph_runtime_state.start_at, + max_execution_time=self.max_execution_time + ): + raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + + # init route node state + route_node_state = self.graph_runtime_state.node_run_state.create_node_state( + node_id=next_node_id + ) + + # get node config + node_id = route_node_state.node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f'Node {node_id} config not found.') + + # convert to specific node + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + if not node_cls: + raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + + previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None + + # init workflow run state + node_instance = node_cls( # type: ignore + id=route_node_state.id, + config=node_config, + graph_init_params=self.init_params, + graph=self.graph, + graph_runtime_state=self.graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=self.thread_pool_id + ) + + try: + # run node + generator = self._run_node( + node_instance=node_instance, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + for item in generator: + if isinstance(item, NodeRunStartedEvent): + self.graph_runtime_state.node_run_steps += 1 + item.route_node_state.index = self.graph_runtime_state.node_run_steps + + yield item + + self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state + + # append route + if previous_route_node_state: + self.graph_runtime_state.node_run_state.add_route( + source_node_state_id=previous_route_node_state.id, + target_node_state_id=route_node_state.id + ) + except Exception as e: + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = str(e) + yield NodeRunFailedEvent( + error=str(e), + id=node_instance.id, + node_id=next_node_id, + node_type=node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + raise e + + # It may not be necessary, but it is necessary. :) + if (self.graph.node_id_config_mapping[next_node_id] + .get("data", {}).get("type", "").lower() == NodeType.END.value): + break + + previous_route_node_state = route_node_state + + # get next node ids + edge_mappings = self.graph.edge_mapping.get(next_node_id) + if not edge_mappings: + break + + if len(edge_mappings) == 1: + edge = edge_mappings[0] + + if edge.run_condition: + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state + ) + + if not result: + break + + next_node_id = edge.target_node_id + else: + final_node_id = None + + if any(edge.run_condition for edge in edge_mappings): + # if nodes has run conditions, get node id which branch to take based on the run condition results + condition_edge_mappings = {} + for edge in edge_mappings: + if edge.run_condition: + run_condition_hash = edge.run_condition.hash + if run_condition_hash not in condition_edge_mappings: + condition_edge_mappings[run_condition_hash] = [] + + condition_edge_mappings[run_condition_hash].append(edge) + + for _, sub_edge_mappings in condition_edge_mappings.items(): + if len(sub_edge_mappings) == 0: + continue + + edge = sub_edge_mappings[0] + + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state, + ) + + if not result: + continue + + if len(sub_edge_mappings) == 1: + final_node_id = edge.target_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=sub_edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + break + + if not final_node_id: + break + + next_node_id = final_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + if not final_node_id: + break + + next_node_id = final_node_id + + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: + break + + def _run_parallel_branches( + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent | str, None, None]: + # if nodes has no run conditions, parallel run all nodes + parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) + if not parallel_id: + node_id = edge_mappings[0].target_node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') + + node_title = node_config.get('data', {}).get('title') + raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') + + parallel = self.graph.parallel_mapping.get(parallel_id) + if not parallel: + raise GraphRunFailedError(f'Parallel {parallel_id} not found.') + + # run parallel nodes, run in new thread and use queue to get results + q: queue.Queue = queue.Queue() + + # Create a list to store the threads + futures = [] + + # new thread + for edge in edge_mappings: + if ( + edge.target_node_id not in self.graph.node_parallel_mapping + or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + ): + continue + + futures.append( + self.thread_pool.submit(self._run_parallel_node, **{ + 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] + 'q': q, + 'parallel_id': parallel_id, + 'parallel_start_node_id': edge.target_node_id, + 'parent_parallel_id': in_parallel_id, + 'parent_parallel_start_node_id': parallel_start_node_id, + }) + ) + + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + + yield event + if event.parallel_id == parallel_id: + if isinstance(event, ParallelBranchRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + + continue + elif isinstance(event, ParallelBranchRunFailedEvent): + raise GraphRunFailedError(event.error) + except queue.Empty: + continue + + # wait all threads + wait(futures) + + # get final node id + final_node_id = parallel.end_to_node_id + if final_node_id: + yield final_node_id + + def _run_parallel_node( + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> None: + """ + Run parallel nodes + """ + with flask_app.app_context(): + try: + q.put(ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + )) + + # run node + generator = self._run( + start_node_id=parallel_start_node_id, + in_parallel_id=parallel_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + for item in generator: + q.put(item) + + # trigger graph run success event + q.put(ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + )) + except GraphRunFailedError as e: + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error + )) + except Exception as e: + logger.exception("Unknown Error when generating in parallel") + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e) + )) + finally: + db.session.remove() + + def _run_node( + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + """ + Run node + """ + # trigger node run start event + yield NodeRunStartedEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + predecessor_node_id=node_instance.previous_node_id, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + db.session.close() + + try: + # run node + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id + + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or 'Unknown error.', + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) + + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + + # add parallel info to run result metadata + if parallel_id and parallel_start_node_id: + if not run_result.metadata: + run_result.metadata = {} + + run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + except GenerateTaskStoppedException: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." + yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") + raise e + finally: + db.session.close() + + def _append_variables_recursively(self, + node_id: str, + variable_key_list: list[str], + variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + self.graph_runtime_state.variable_pool.add( + [node_id] + variable_key_list, + variable_value + ) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + node_id=node_id, + variable_key_list=new_key_list, + variable_value=value + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + +class GraphRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 5bae27092f..8cf01727ec 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,9 +1,8 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, GenerateRouteChunk, @@ -19,24 +18,26 @@ class AnswerNode(BaseNode): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(AnswerNodeData, node_data) # generate routes - generate_routes = self.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) answer = '' for part in generate_routes: - if part.type == "var": + if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = variable_pool.get(value_selector) + value = self.graph_runtime_state.variable_pool.get( + value_selector + ) + if value: answer += value.markdown else: @@ -51,70 +52,16 @@ class AnswerNode(BaseNode): ) @classmethod - def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: - """ - Extract generate route selectors - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(AnswerNodeData, node_data) - - return cls.extract_generate_route_from_node_data(node_data) - - @classmethod - def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: - """ - Extract generate route from node data - :param node_data: node data object - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors - } - - variable_keys = list(value_selector_mapping.keys()) - - # format answer template - template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - generate_routes = [] - for part in template.split('Ω'): - if part: - if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') - value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) - else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) - - return generate_routes - - @classmethod - def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -126,6 +73,6 @@ class AnswerNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py new file mode 100644 index 0000000000..6cb80091c9 --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -0,0 +1,169 @@ + +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + AnswerStreamGenerateRoute, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class AnswerStreamGeneratorRouter: + + @classmethod + def init(cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selectors of answer nodes + answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} + for answer_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + continue + + # get generate route for stream output + generate_route = cls._extract_generate_route_selectors(node_config) + answer_generate_route[answer_node_id] = generate_route + + # fetch answer dependencies + answer_node_ids = list(answer_generate_route.keys()) + answer_dependencies = cls._fetch_answers_dependencies( + answer_node_ids=answer_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping + ) + + return AnswerStreamGenerateRoute( + answer_generate_route=answer_generate_route, + answer_dependencies=answer_dependencies + ) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in variable_selectors + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + generate_routes: list[GenerateRouteChunk] = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) + + return generate_routes + + @classmethod + def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = AnswerNodeData(**config.get("data", {})) + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys + + @classmethod + def _fetch_answers_dependencies(cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict] + ) -> dict[str, list[str]]: + """ + Fetch answer dependencies + :param answer_node_ids: answer node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + answer_dependencies: dict[str, list[str]] = {} + for answer_node_id in answer_node_ids: + if answer_dependencies.get(answer_node_id) is None: + answer_dependencies[answer_node_id] = [] + + cls._recursive_fetch_answer_dependencies( + current_node_id=answer_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies + ) + + return answer_dependencies + + @classmethod + def _recursive_fetch_answer_dependencies(cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]] + ) -> None: + """ + Recursive fetch answer dependencies + :param current_node_id: current node id + :param answer_node_id: answer node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param answer_dependencies: answer dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + if source_node_type in ( + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + ): + answer_dependencies[answer_node_id].append(source_node_id) + else: + cls._recursive_fetch_answer_dependencies( + current_node_id=source_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies + ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py new file mode 100644 index 0000000000..c2a5dd5163 --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -0,0 +1,221 @@ +import logging +from collections.abc import Generator +from typing import Optional, cast + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk + +logger = logging.getLogger(__name__) + + +class AnswerStreamProcessor(StreamProcessor): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.generate_routes = graph.answer_stream_generate_routes + self.route_position = {} + for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + self.route_position[answer_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) + self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] = stream_out_answer_node_ids + + for _ in stream_out_answer_node_ids: + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[answer_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + self.route_position[answer_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished(self, + event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for answer_node_id, position in self.route_position.items(): + # all depends on answer node id not in rest node ids + if (event.route_node_state.node_id != answer_node_id + and (answer_node_id not in self.rest_node_ids + or not all(dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): + continue + + route_position = self.route_position[answer_node_id] + route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:] + + for route_chunk in route_chunks: + if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT: + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=route_chunk.text, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + else: + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + if not value_selector: + break + + value = self.variable_pool.get( + value_selector + ) + + if value is None: + break + + text = value.markdown + + if text: + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[answer_node_id] += 1 + + def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_answer_node_ids = [] + for answer_node_id, route_position in self.route_position.items(): + if answer_node_id not in self.rest_node_ids: + continue + + # all depends on answer node id not in rest node ids + if all(dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): + if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + continue + + route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] + + if route_chunk.type != GenerateRouteChunk.ChunkType.VAR: + continue + + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_answer_node_ids.append(answer_node_id) + + return stream_out_answer_node_ids + + @classmethod + def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]: + """ + Fetch files from variable value + :param value: variable value + :return: + """ + if not value: + return [] + + files = [] + if isinstance(value, list): + for item in value: + file_var = cls._get_file_var_from_value(item) + if file_var: + files.append(file_var) + elif isinstance(value, dict): + file_var = cls._get_file_var_from_value(value) + if file_var: + files.append(file_var) + + return files + + @classmethod + def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict): + if '__variant' in value and value['__variant'] == FileVar.__name__: + return value + elif isinstance(value, FileVar): + return value.to_dict() + + return None diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py new file mode 100644 index 0000000000..cbabbca37d --- /dev/null +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.graph import Graph + + +class StreamProcessor(ABC): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + self.graph = graph + self.variable_pool = variable_pool + self.rest_node_ids = graph.node_ids.copy() + + @abstractmethod + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + raise NotImplementedError + + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + finished_node_id = event.route_node_state.node_id + if finished_node_id not in self.rest_node_ids: + return + + # remove finished node id + self.rest_node_ids.remove(finished_node_id) + + run_result = event.route_node_state.node_run_result + if not run_result: + return + + if run_result.edge_source_handle: + reachable_node_ids = [] + unreachable_first_node_ids = [] + for edge in self.graph.edge_mapping[finished_node_id]: + if (edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify): + reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + continue + else: + unreachable_first_node_ids.append(edge.target_node_id) + + for node_id in unreachable_first_node_ids: + self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + node_ids = [] + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + + node_ids.append(edge.target_node_id) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + return node_ids + + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + """ + remove target node ids until merge + """ + if node_id not in self.rest_node_ids: + return + + self.rest_node_ids.remove(node_id) + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id in reachable_node_ids: + continue + + self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 9effbbbe67..620c2c426b 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,6 @@ +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ - answer: str + answer: str = Field(..., description="answer template string") class GenerateRouteChunk(BaseModel): """ Generate Route Chunk. """ - type: str + + class ChunkType(Enum): + VAR = "var" + TEXT = "text" + + type: ChunkType = Field(..., description="generate route chunk type") class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ - type: str = "var" - value_selector: list[str] + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR + """generate route chunk type""" + value_selector: list[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ - type: str = "text" - text: str + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT + """generate route chunk type""" + text: str = Field(..., description="text") + + +class AnswerNodeDoubleLink(BaseModel): + node_id: str = Field(..., description="node id") + source_node_ids: list[str] = Field(..., description="source node ids") + target_node_ids: list[str] = Field(..., description="target node ids") + + +class AnswerStreamGenerateRoute(BaseModel): + """ + AnswerStreamGenerateRoute entity + """ + answer_dependencies: dict[str, list[str]] = Field( + ..., + description="answer dependencies (answer node id -> dependent answer node ids)" + ) + answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( + ..., + description="answer generate route (answer node id -> generate route chunks)" + ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3d9cf52771..b9912314f1 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,142 +1,103 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from enum import Enum +from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from models import WorkflowNodeExecutionStatus - - -class UserFrom(Enum): - """ - User from - """ - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent, RunEvent class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - tenant_id: str - app_id: str - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - node_id: str - node_data: BaseNodeData - node_run_result: Optional[NodeRunResult] = None - - callbacks: Sequence[WorkflowCallback] - - is_answer_previous_node: bool = False - - def __init__(self, tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, + def __init__(self, + id: str, config: Mapping[str, Any], - callbacks: Sequence[WorkflowCallback] | None = None, - workflow_call_depth: int = 0) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - self.workflow_id = workflow_id - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None) -> None: + self.id = id + self.tenant_id = graph_init_params.tenant_id + self.app_id = graph_init_params.app_id + self.workflow_type = graph_init_params.workflow_type + self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config + self.user_id = graph_init_params.user_id + self.user_from = graph_init_params.user_from + self.invoke_from = graph_init_params.invoke_from + self.workflow_call_depth = graph_init_params.call_depth + self.graph = graph + self.graph_runtime_state = graph_runtime_state + self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id - # TODO: May need to check if key exists. - self.node_id = config["id"] - if not self.node_id: + node_id = config.get("id") + if not node_id: raise ValueError("Node ID is required.") + self.node_id = node_id self.node_data = self._node_data_cls(**config.get("data", {})) - self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) \ + -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ raise NotImplementedError - def run(self, variable_pool: VariablePool) -> NodeRunResult: + def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node entry - :param variable_pool: variable pool :return: """ - try: - result = self._run( - variable_pool=variable_pool - ) - self.node_run_result = result - return result - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + result = self._run() - def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: - """ - Publish text chunk - :param text: chunk text - :param value_selector: value selector - :return: - """ - if self.callbacks: - for callback in self.callbacks: - callback.on_node_text_chunk( - node_id=self.node_id, - text=text, - metadata={ - "node_type": self.node_type, - "is_answer_previous_node": self.is_answer_previous_node, - "value_selector": value_selector - } - ) + if isinstance(result, NodeRunResult): + yield RunCompletedEvent( + run_result=result + ) + else: + yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict): + def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config :param config: node config :return: """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping(node_data) + return cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + node_id=node_id, + node_data=node_data + ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: BaseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -158,38 +119,3 @@ class BaseNode(ABC): :return: """ return self._node_type - -class BaseIterationNode(BaseNode): - @abstractmethod - def _run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node - :param variable_pool: variable pool - :return: - """ - raise NotImplementedError - - def run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node entry - :param variable_pool: variable pool - :return: - """ - return self._run(variable_pool=variable_pool) - - def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - return self._get_next_iteration(variable_pool, state) - - @abstractmethod - def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - raise NotImplementedError diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 335991ae87..955afdfa1d 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage @@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -33,13 +33,13 @@ class CodeNode(BaseNode): return code_provider.get_default_config() - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run code - :param variable_pool: variable pool :return: """ - node_data = cast(CodeNodeData, self.node_data) + node_data = self.node_data + node_data = cast(CodeNodeData, node_data) # Get code language code_language = node_data.code_language @@ -49,7 +49,7 @@ class CodeNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable] = value # Run code @@ -311,13 +311,19 @@ class CodeNode(BaseNode): return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 440dfa2f27..552914b308 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,10 +11,9 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData _node_type = NodeType.END - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data @@ -24,7 +22,7 @@ class EndNode(BaseNode): outputs = {} for variable_selector in output_variables: - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value return NodeRunResult( @@ -34,52 +32,16 @@ class EndNode(BaseNode): ) @classmethod - def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]: - """ - Extract generate nodes - :param graph: graph - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(EndNodeData, node_data) - - return cls.extract_generate_nodes_from_node_data(graph, node_data) - - @classmethod - def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]: - """ - Extract generate nodes from node data - :param graph: graph - :param node_data: node data object - :return: - """ - nodes = graph.get('nodes', []) - node_mapping = {node.get('id'): node for node in nodes} - - variable_selectors = node_data.outputs - - generate_nodes = [] - for variable_selector in variable_selectors: - if not variable_selector.value_selector: - continue - - node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_mapping: - node = node_mapping[node_id] - node_type = node.get('data', {}).get('type') - if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': - generate_nodes.append(node_id) - - # remove duplicates - generate_nodes = list(set(generate_nodes)) - - return generate_nodes - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py new file mode 100644 index 0000000000..8390f6d81b --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -0,0 +1,148 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam + + +class EndStreamGeneratorRouter: + + @classmethod + def init(cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str] + ) -> EndStreamParam: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selector of end nodes + end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} + for end_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get('data', {}).get('type') == NodeType.END.value: + continue + + # skip end node in parallel + if end_node_id in node_parallel_mapping: + continue + + # get generate route for stream output + stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) + end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors + + # fetch end dependencies + end_node_ids = list(end_stream_variable_selectors_mapping.keys()) + end_dependencies = cls._fetch_ends_dependencies( + end_node_ids=end_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping + ) + + return EndStreamParam( + end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, + end_dependencies=end_dependencies + ) + + @classmethod + def extract_stream_variable_selector_from_node_data(cls, + node_id_config_mapping: dict[str, dict], + node_data: EndNodeData) -> list[list[str]]: + """ + Extract stream variable selector from node data + :param node_id_config_mapping: node id config mapping + :param node_data: node data object + :return: + """ + variable_selectors = node_data.outputs + + value_selectors = [] + for variable_selector in variable_selectors: + if not variable_selector.value_selector: + continue + + node_id = variable_selector.value_selector[0] + if node_id != 'sys' and node_id in node_id_config_mapping: + node = node_id_config_mapping[node_id] + node_type = node.get('data', {}).get('type') + if ( + variable_selector.value_selector not in value_selectors + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == 'text' + ): + value_selectors.append(variable_selector.value_selector) + + return value_selectors + + @classmethod + def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ + -> list[list[str]]: + """ + Extract stream variable selector from node config + :param node_id_config_mapping: node id config mapping + :param config: node config + :return: + """ + node_data = EndNodeData(**config.get("data", {})) + return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) + + @classmethod + def _fetch_ends_dependencies(cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict] + ) -> dict[str, list[str]]: + """ + Fetch end dependencies + :param end_node_ids: end node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + end_dependencies: dict[str, list[str]] = {} + for end_node_id in end_node_ids: + if end_dependencies.get(end_node_id) is None: + end_dependencies[end_node_id] = [] + + cls._recursive_fetch_end_dependencies( + current_node_id=end_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies + ) + + return end_dependencies + + @classmethod + def _recursive_fetch_end_dependencies(cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]] + ) -> None: + """ + Recursive fetch end dependencies + :param current_node_id: current node id + :param end_node_id: end node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param end_dependencies: end dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + if source_node_type in ( + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + ): + end_dependencies[end_node_id].append(source_node_id) + else: + cls._recursive_fetch_end_dependencies( + current_node_id=source_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies + ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py new file mode 100644 index 0000000000..4474c2a78a --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -0,0 +1,191 @@ +import logging +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +class EndStreamProcessor(StreamProcessor): + + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.end_stream_param = graph.end_stream_param + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + self.has_outputed = False + self.outputed_node_ids = set() + + def process(self, + generator: Generator[GraphEngineEvent, None, None] + ) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) + self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] = stream_out_end_node_ids + + if stream_out_end_node_ids: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = '\n' + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[end_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished(self, + event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for end_node_id, position in self.route_position.items(): + # all depends on end node id not in rest node ids + if (event.route_node_state.node_id != end_node_id + and (end_node_id not in self.rest_node_ids + or not all(dep_id not in self.rest_node_ids + for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + continue + + route_position = self.route_position[end_node_id] + + position = 0 + value_selectors = [] + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position >= route_position: + value_selectors.append(current_value_selectors) + + position += 1 + + for value_selector in value_selectors: + if not value_selector: + continue + + value = self.variable_pool.get( + value_selector + ) + + if value is None: + break + + text = value.markdown + + if text: + current_node_id = value_selector[0] + if self.has_outputed and current_node_id not in self.outputed_node_ids: + text = '\n' + text + + self.outputed_node_ids.add(current_node_id) + self.has_outputed = True + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[end_node_id] += 1 + + def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_end_node_ids = [] + for end_node_id, route_position in self.route_position.items(): + if end_node_id not in self.rest_node_ids: + continue + + # all depends on end node id not in rest node ids + if all(dep_id not in self.rest_node_ids + for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): + continue + + position = 0 + value_selector = None + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position == route_position: + value_selector = current_value_selectors + break + + position += 1 + + if not value_selector: + continue + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_end_node_ids.append(end_node_id) + + return stream_out_end_node_ids diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index ad4fc8f04f..a0edf7b579 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,3 +1,5 @@ +from pydantic import BaseModel, Field + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData): END Node Data. """ outputs: list[VariableSelector] + + +class EndStreamParam(BaseModel): + """ + EndStreamParam entity + """ + end_dependencies: dict[str, list[str]] = Field( + ..., + description="end dependencies (end node id -> dependent node ids)" + ) + end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( + ..., + description="end stream variable selector mapping (end node id -> stream variable selectors)" + ) diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event.py new file mode 100644 index 0000000000..276c13a6d4 --- /dev/null +++ b/api/core/workflow/nodes/event.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 037a7a1848..3f68c8b1d0 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,15 +1,14 @@ import logging +from collections.abc import Mapping, Sequence from mimetypes import guess_extension from os import path -from typing import cast +from typing import Any, cast from configs import dify_config from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( HttpRequestNodeData, @@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode): }, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text + node_data.authorization.config.api_key = parser.convert_template( + template=node_data.authorization.config.api_key, + variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None try: http_executor = HttpExecutor( - node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool + node_data=node_data, + timeout=self._get_request_timeout(node_data), + variable_pool=self.graph_runtime_state.variable_pool ) # invoke http executor @@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode): return timeout @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = cast(HttpRequestNodeData, node_data) try: http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) @@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 7eb69b80df..338277ace1 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -3,20 +3,7 @@ from typing import Literal, Optional from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match", - # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" - ] - value: Optional[str] = None +from core.workflow.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 2b253764b7..ca87eecd0d 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,13 +1,10 @@ -import re -from collections.abc import Sequence -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData -from core.workflow.utils.variable_template_parser import VariableTemplateParser +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus @@ -15,31 +12,35 @@ class IfElseNode(BaseNode): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs = { + node_inputs: dict[str, list] = { "conditions": [] } - process_datas = { + process_datas: dict[str, list] = { "condition_results": [] } input_conditions = [] final_result = False selected_case_id = None + condition_processor = ConditionProcessor() try: # Check if the new cases structure is used if node_data.cases: for case in node_data.cases: - input_conditions, group_result = self.process_conditions(variable_pool, case.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=case.conditions + ) + # Apply the logical operator for the current case final_result = all(group_result) if case.logical_operator == "and" else any(group_result) @@ -58,7 +59,10 @@ class IfElseNode(BaseNode): else: # Fallback to old structure if cases are not defined - input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=node_data.conditions + ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) @@ -94,376 +98,17 @@ class IfElseNode(BaseNode): return data - def evaluate_condition( - self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator - - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - elif comparison_operator == "regex match": - return self._assert_regex_match(actual_value, expected_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") - - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - for condition in conditions: - actual_variable = variable_pool.get_any(condition.variable_selector) - - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - expected_value = variable_template_parser.extract_variable_selectors() - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - else: - expected_value = condition.value - else: - expected_value = None - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_variable, - "expected_value": expected_value, - "comparison_operator": comparison_operator - } - ) - - result = self.evaluate_condition(actual_variable, expected_value, comparison_operator) - group_result.append(result) - - return input_conditions, group_result - - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value not in actual_value: - return False - return True - - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True - return False - - def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if actual_value is None: - return False - return re.search(expected_value, actual_value) is not None - - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True - return False - - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True - return False - - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True - return False - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 177b47b951..5fc5a827ae 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData class IterationNodeData(BaseIterationNodeData): @@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData): iterator_selector: list[str] # variable selector output_selector: list[str] # output selector + +class IterationStartNodeData(BaseNodeData): + """ + Iteration Start Node Data. + """ + pass + class IterationState(BaseIterationState): """ Iteration State. diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 54dfe8b7f4..93eff16c33 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,124 +1,371 @@ -from typing import cast +import logging +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime, timezone +from typing import Any, cast +from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode -from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.graph_engine.entities.event import ( + BaseGraphEvent, + BaseNodeEvent, + BaseParallelBranchEvent, + GraphRunFailedEvent, + InNodeEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) -class IterationNode(BaseIterationNode): + +class IterationNode(BaseNode): """ Iteration Node. """ _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run the node. """ self.node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(self.node_data.iterator_selector) + iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - if not isinstance(iterator, list): - raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") - - state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ - 'iterator_selector': iterator - }, outputs=[], metadata=IterationState.MetaData( - iterator_length=len(iterator) if iterator is not None else 0 - )) + if not iterator_list_segment: + raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - self._set_current_iteration_variable(variable_pool, state) - return state + iterator_list_value = iterator_list_segment.to_object() - def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - # resolve current output - self._resolve_current_output(variable_pool, state) - # move to next iteration - self._next_iteration(variable_pool, state) + if not isinstance(iterator_list_value, list): + raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - node_data = cast(IterationNodeData, self.node_data) - if self._reached_iteration_limit(variable_pool, state): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs = { + "iterator_selector": iterator_list_value + } + + graph_config = self.graph_config + + if not self.node_data.start_node_id: + raise ValueError(f'field start_node_id in iteration {self.node_id} not found') + + root_node_id = self.node_data.start_node_id + + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=root_node_id + ) + + if not iteration_graph: + raise ValueError('iteration graph not found') + + leaf_node_ids = iteration_graph.get_leaf_node_ids() + iteration_leaf_node_ids = [] + for leaf_node_id in leaf_node_ids: + node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) + if not node_config: + continue + + leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") + if not leaf_node_iteration_id: + continue + + if leaf_node_iteration_id != self.node_id: + continue + + iteration_leaf_node_ids.append(leaf_node_id) + + # add condition of end nodes to root node + iteration_graph.add_extra_edge( + source_node_id=leaf_node_id, + target_node_id=root_node_id, + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=[self.node_id, "index"], + comparison_operator="<", + value=str(len(iterator_list_value)) + ) + ] + ) + ) + + variable_pool = self.graph_runtime_state.variable_pool + + # append iteration variable (item, index) to variable pool + variable_pool.add( + [self.node_id, 'index'], + 0 + ) + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[0] + ) + + # init graph engine + from core.workflow.graph_engine.graph_engine import GraphEngine + graph_engine = GraphEngine( + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_type=self.workflow_type, + workflow_id=self.workflow_id, + user_id=self.user_id, + user_from=self.user_from, + invoke_from=self.invoke_from, + call_depth=self.workflow_call_depth, + graph=iteration_graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + ) + + start_at = datetime.now(timezone.utc).replace(tzinfo=None) + + yield IterationRunStartedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + metadata={ + "iterator_length": len(iterator_list_value) + }, + predecessor_node_id=self.previous_node_id + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=0, + pre_iteration_output=None + ) + + outputs: list[Any] = [] + try: + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START: + continue + + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) + event.route_node_state.node_run_result.metadata = metadata + + yield event + + # handle iteration run result + if event.route_node_state.node_id in iteration_leaf_node_ids: + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, 'index']) + if current_index is None: + raise ValueError(f'iteration {self.node_id} current index not found') + + next_index = int(current_index.to_object()) + 1 + variable_pool.add( + [self.node_id, 'index'], + next_index + ) + + if next_index < len(iterator_list_value): + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[next_index] + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder( + current_iteration_output) if current_iteration_output else None + ) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={ + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens + }, + error=event.error, + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + break + else: + event = cast(InNodeEvent, event) + yield event + + yield IterationRunSucceededEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, outputs={ - 'output': jsonable_encoder(state.outputs) + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens } ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'output': jsonable_encoder(outputs) + } + ) + ) + except Exception as e: + # iteration run failed + logger.exception("Iteration run failed") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={ + "output": jsonable_encoder(outputs) + }, + steps=len(iterator_list_value), + metadata={ + "total_tokens": graph_engine.graph_runtime_state.total_tokens + }, + error=str(e), + ) - return node_data.start_node_id + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + finally: + # remove iteration variable (item, index) from variable pool after iteration run completed + variable_pool.remove([self.node_id, 'index']) + variable_pool.remove([self.node_id, 'item']) - def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): - """ - Set current iteration variable. - :variable_pool: variable pool - """ - node_data = cast(IterationNodeData, self.node_data) - - variable_pool.add((self.node_id, 'index'), state.index) - # get the iterator value - iterator = variable_pool.get_any(node_data.iterator_selector) - - if iterator is None or not isinstance(iterator, list): - return - - if state.index < len(iterator): - variable_pool.add((self.node_id, 'item'), iterator[state.index]) - - def _next_iteration(self, variable_pool: VariablePool, state: IterationState): - """ - Move to next iteration. - :param variable_pool: variable pool - """ - state.index += 1 - self._set_current_iteration_variable(variable_pool, state) - - def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): - """ - Check if iteration limit is reached. - :return: True if iteration limit is reached, False otherwise - """ - node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(node_data.iterator_selector) - - if iterator is None or not isinstance(iterator, list): - return True - - return state.index >= len(iterator) - - def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): - """ - Resolve current output. - :param variable_pool: variable pool - """ - output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_any(output_selector) - # clear the output for this iteration - variable_pool.remove([self.node_id] + output_selector[1:]) - state.current_output = output - if output is not None: - # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). - if isinstance(output, list): - state.outputs.extend(output) - else: - state.outputs.append(output) - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - 'input_selector': node_data.iterator_selector, - } \ No newline at end of file + variable_mapping = { + f'{node_id}.input_selector': node_data.iterator_selector, + } + + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=node_data.start_node_id + ) + + if not iteration_graph: + raise ValueError('iteration graph not found') + + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): + if sub_node_config.get('data', {}).get('iteration_id') != node_id: + continue + + # variable selector to variable mapping + try: + # Get node class + from core.workflow.nodes.node_mapping import node_classes + node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + if not node_cls: + continue + + node_cls = cast(BaseNode, node_cls) + + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + config=sub_node_config + ) + sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + except NotImplementedError: + sub_node_variable_mapping = {} + + # remove iteration variables + sub_node_variable_mapping = { + sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() + if value[0] != node_id + } + + variable_mapping.update(sub_node_variable_mapping) + + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() + if value[0] not in iteration_graph.node_ids + } + + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py new file mode 100644 index 0000000000..25044cf3eb --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationStartNode(BaseNode): + """ + Iteration Start Node. + """ + _node_data_cls = IterationStartNodeData + _node_type = NodeType.ITERATION_START + + def _run(self) -> NodeRunResult: + """ + Run the node. + """ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 6c052c0d6b..2d1ac4731c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,3 +1,5 @@ +import logging +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import func @@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) + default_retrieval_model = { 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, @@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData node_type = NodeType.KNOWLEDGE_RETRIEVAL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + def _run(self) -> NodeRunResult: + node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = variable_pool.get_any(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable variables = { 'query': query @@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode): ) except Exception as e: - + logger.exception("Error when running knowledge retrieval node") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode): return context_list @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: KnowledgeRetrievalNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {} - variable_mapping['query'] = node_data.query_variable_selector + variable_mapping[node_id + '.query'] = node_data.query_variable_selector return variable_mapping def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 737b1af143..f26ec1b0b5 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,16 +1,17 @@ import json -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast + +from pydantic import BaseModel from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, @@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, @@ -43,17 +46,26 @@ if TYPE_CHECKING: +class ModelInvokeCompleted(BaseModel): + """ + Model invoke completed + """ + text: str + usage: LLMUsage + finish_reason: Optional[str] = None + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ node_data = cast(LLMNodeData, deepcopy(self.node_data)) + variable_pool = self.graph_runtime_state.variable_pool node_inputs = None process_data = None @@ -80,10 +92,15 @@ class LLMNode(BaseNode): node_inputs['#files#'] = [file.to_dict() for file in files] # fetch context value - context = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data, variable_pool) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield event if context: - node_inputs['#context#'] = context + node_inputs['#context#'] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -115,19 +132,34 @@ class LLMNode(BaseNode): } # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = '' + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) ) + return outputs = { 'text': result_text, @@ -135,22 +167,26 @@ class LLMNode(BaseNode): 'finish_reason': finish_reason } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + }, + llm_usage=usage + ) ) def _invoke_llm(self, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: list[str]) -> tuple[str, LLMUsage]: + stop: Optional[list[str]] = None) \ + -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -170,23 +206,31 @@ class LLMNode(BaseNode): ) # handle invoke result - text, usage, finish_reason = self._handle_invoke_result( + generator = self._handle_invoke_result( invoke_result=invoke_result ) + usage = LLMUsage.empty_usage() + for event in generator: + yield event + if isinstance(event, ModelInvokeCompleted): + usage = event.usage + # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage, finish_reason - - def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ + -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result :return: """ + if isinstance(invoke_result, LLMResult): + return + model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = '' usage = None finish_reason = None @@ -194,7 +238,10 @@ class LLMNode(BaseNode): text = result.delta.message.content full_text += text - self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) + yield RunStreamChunkEvent( + chunk_content=text, + from_variable_selector=[self.node_id, 'text'] + ) if not model: model = result.model @@ -211,11 +258,15 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - return full_text, usage, finish_reason + yield ModelInvokeCompleted( + text=full_text, + usage=usage, + finish_reason=finish_reason + ) def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -224,13 +275,13 @@ class LLMNode(BaseNode): """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2': + if messages.edition_type == 'jinja2' and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2': + if message.edition_type == 'jinja2' and message.jinja2_text: message.text = message.jinja2_text return messages @@ -348,7 +399,7 @@ class LLMNode(BaseNode): return files - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: """ Fetch context :param node_data: node data @@ -356,15 +407,18 @@ class LLMNode(BaseNode): :return: """ if not node_data.context.enabled: - return None + return if not node_data.context.variable_selector: - return None + return context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - return context_value + yield RunRetrieverResourceEvent( + retriever_resources=[], + context=context_value + ) elif isinstance(context_value, list): context_str = '' original_retriever_resource = [] @@ -381,17 +435,10 @@ class LLMNode(BaseNode): if retriever_resource: original_retriever_resource.append(retriever_resource) - if self.callbacks and original_retriever_resource: - for callback in self.callbacks: - callback.on_event( - event=QueueRetrieverResourcesEvent( - retriever_resources=original_retriever_resource - ) - ) - - return context_str.strip() - - return None + yield RunRetrieverResourceEvent( + retriever_resources=original_retriever_resource, + context=context_str.strip() + ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: """ @@ -574,7 +621,8 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent): + if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( + content_item, ImagePromptMessageContent): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -646,13 +694,19 @@ class LLMNode(BaseNode): db.session.commit() @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - prompt_template = node_data.prompt_template variable_selectors = [] @@ -702,6 +756,10 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 7d53c6f5f2..526404e30d 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,20 +1,34 @@ -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode +from typing import Any + +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.loop.entities import LoopNodeData, LoopState +from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseIterationNode): +class LoopNode(BaseNode): """ Loop Node. """ _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self, variable_pool: VariablePool) -> LoopState: - return super()._run(variable_pool) + def _run(self) -> LoopState: + return super()._run() - def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ - Get next iteration start node id based on the graph. + Get conditions. """ + node_id = node_config.get('id') + if not node_id: + return [] + + # TODO waiting for implementation + return [Condition( + variable_selector=[node_id, 'index'], + comparison_operator="≤", + value_type="value_selector", + value_selector=[] + )] diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py new file mode 100644 index 0000000000..b98525e86e --- /dev/null +++ b/api/core/workflow/nodes/node_mapping.py @@ -0,0 +1,37 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode + +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.ITERATION_START: IterationStartNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, +} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2876695a82..2e65705f10 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,6 +1,7 @@ import json import uuid -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode): } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = variable_pool.get_any(node_data.query) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) if not variable: raise ValueError("Input variable content not found or is empty") query = variable @@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode): raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ and node_data.reasoning_mode == 'function_call': # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, variable_pool, model_config, memory + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, + prompt_messages = self._generate_prompt_engineering_prompt(node_data, + query, + self.graph_runtime_state.variable_pool, + model_config, memory) prompt_message_tools = [] @@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) def _invoke_llm(self, node_data_model: ModelConfig, @@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode): return self._model_instance, self._model_config @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - variable_mapping = { 'query': node_data.query } @@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode): for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index f4057d50f3..ecab8db9b6 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,12 @@ import json import logging -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder @@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, @@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData node_type = NodeType.QUESTION_CLASSIFIER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) node_data = cast(QuestionClassifierNodeData, node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables variable = variable_pool.get(node_data.query_variable_selector) @@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode): ) # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = '' + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + category_name = node_data.classes[0].name category_id = node_data.classes[0].id try: @@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) except ValueError as e: @@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {'query': node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: @@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode): variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 54e66bd671..69cdec6a92 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,9 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData +from collections.abc import Mapping, Sequence +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -11,14 +13,13 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData _node_type = NodeType.START - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_inputs = dict(variable_pool.user_inputs) - system_inputs = variable_pool.system_variables + node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + system_inputs = self.graph_runtime_state.variable_pool.system_variables for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] @@ -30,9 +31,16 @@ class StartNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 21f71db6c5..b14a394a0a 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,15 +1,16 @@ import os -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) + class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode): } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node """ @@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable_name = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable_name] = value # Run code try: @@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, error=str(e) ) - + if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, @@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode): 'output': result['result'] } ) - + @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } \ No newline at end of file + node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ccce9ef360..feedeb6dad 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -26,7 +26,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the tool node """ @@ -56,8 +56,8 @@ class ToolNode(BaseNode): # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) + parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) try: messages = ToolEngine.workflow_invoke( @@ -66,6 +66,7 @@ class ToolNode(BaseNode): user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, ) except Exception as e: return NodeRunResult( @@ -145,7 +146,8 @@ class ToolNode(BaseNode): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): + def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ + -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -221,9 +223,16 @@ class ToolNode(BaseNode): return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -239,4 +248,8 @@ class ToolNode(BaseNode): elif input.type == 'constant': pass + result = { + node_id + '.' + key: value for key, value in result.items() + } + return result diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 885f7d7617..6944d9e82d 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} @@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode): if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: for selector in node_data.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs = { "output": variable @@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode): else: for group in node_data.advanced_settings.groups: for selector in group.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs[group.group_name] = { @@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ return {} diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 8c2adcabb9..b2f32c6aaa 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session from core.app.segments import SegmentType, Variable, factory from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from extensions.ext_database import db from models import ConversationVariable, WorkflowNodeExecutionStatus @@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: data = cast(VariableAssignerData, self.node_data) # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableAssignerNodeError('assigned variable not found') match data.write_mode: case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_variable = original_variable.model_copy(update={'value': income_value.value}) case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: raise VariableAssignerNodeError('input value not found') updated_value = original_variable.value + [income_value.value] @@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode): raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) if not conversation_id: raise VariableAssignerNodeError('conversation_id not found') update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) diff --git a/api/core/workflow/utils/condition/__init__.py b/api/core/workflow/utils/condition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py new file mode 100644 index 0000000000..e195730a31 --- /dev/null +++ b/api/core/workflow/utils/condition/entities.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py new file mode 100644 index 0000000000..5ff61aab3d --- /dev/null +++ b/api/core/workflow/utils/condition/processor.py @@ -0,0 +1,383 @@ +from collections.abc import Sequence +from typing import Any, Optional + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.condition.entities import Condition +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class ConditionProcessor: + def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): + input_conditions = [] + group_result = [] + + index = 0 + for condition in conditions: + index += 1 + actual_value = variable_pool.get_any( + condition.variable_selector + ) + + expected_value = None + if condition.value is not None: + variable_template_parser = VariableTemplateParser(template=condition.value) + variable_selectors = variable_template_parser.extract_variable_selectors() + if variable_selectors: + for variable_selector in variable_selectors: + value = variable_pool.get_any( + variable_selector.value_selector + ) + expected_value = variable_template_parser.format({variable_selector.variable: value}) + + if expected_value is None: + expected_value = condition.value + else: + expected_value = condition.value + + comparison_operator = condition.comparison_operator + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": comparison_operator + } + ) + + result = self.evaluate_condition(actual_value, comparison_operator, expected_value) + group_result.append(result) + + return input_conditions, group_result + + def evaluate_condition( + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None + ) -> bool: + """ + Evaluate condition + :param actual_value: actual value + :param expected_value: expected value + :param comparison_operator: comparison operator + + :return: bool + """ + if comparison_operator == "contains": + return self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + return self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + return self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + return self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + return self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + return self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + return self._assert_empty(actual_value) + elif comparison_operator == "not empty": + return self._assert_not_empty(actual_value) + elif comparison_operator == "=": + return self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + return self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + return self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + return self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + return self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + return self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + return self._assert_null(actual_value) + elif comparison_operator == "not null": + return self._assert_not_null(actual_value) + else: + raise ValueError(f"Invalid comparison operator: {comparison_operator}") + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + +class ConditionAssertionError(Exception): + def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: + self.message = message + self.conditions = conditions + self.sub_condition_compare_results = sub_condition_compare_results + super().__init__(self.message) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3157eedfee..e69de29bb2 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,1005 +0,0 @@ -import logging -import time -from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast - -import contexts -from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.entities import IterationState -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode -from core.workflow.nodes.variable_assigner import VariableAssignerNode -from extensions.ext_database import db -from models.workflow import ( - Workflow, - WorkflowNodeExecutionStatus, -) - -node_classes: Mapping[NodeType, type[BaseNode]] = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, - NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, -} - -logger = logging.getLogger(__name__) - - -class WorkflowEngineManager: - def get_default_configs(self) -> list[dict]: - """ - Get default block configs - """ - default_block_configs = [] - for node_type, node_class in node_classes.items(): - default_config = node_class.get_default_config() - if default_config: - default_block_configs.append(default_config) - - return default_block_configs - - def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: - """ - Get default config of node. - :param node_type: node type - :param filters: filter by node config parameters. - :return: - """ - node_class = node_classes.get(node_type) - if not node_class: - return None - - default_config = node_class.get_default_config(filters=filters) - if not default_config: - return None - - return default_config - - def run_workflow( - self, - *, - workflow: Workflow, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - callbacks: Sequence[WorkflowCallback], - call_depth: int = 0, - variable_pool: VariablePool | None = None, - ) -> None: - """ - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param invoke_from: invoke from - :param callbacks: workflow callbacks - :param call_depth: call depth - :param variable_pool: variable pool - """ - # fetch workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - if 'nodes' not in graph or 'edges' not in graph: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - - - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH - if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - - # init workflow run state - if not variable_pool: - variable_pool = contexts.workflow_variable_pool.get() - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - workflow_call_depth=call_depth - ) - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - ) - - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: - """ - Run workflow - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :param call_depth: call depth - :param start_at: force specific start node - :param end_at: force specific end node - :return: - """ - graph = workflow.graph_dict - - try: - answer_prov_node_ids = [] - for node in graph.get('nodes', []): - if node.get('id', '') == 'answer': - try: - answer_prov_node_ids.append(node.get('data', {}) - .get('answer', '') - .replace('#', '') - .replace('.text', '') - .replace('{{', '') - .replace('}}', '').split('.')[0]) - except Exception as e: - logger.error(e) - - predecessor_node: BaseNode | None = None - current_iteration_node: BaseIterationNode | None = None - has_entry_node = False - max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS - max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME - while True: - # get next node, multiple target nodes in the future - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=predecessor_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - - if not next_node: - # reached loop/iteration end or overall end - if current_iteration_node and workflow_run_state.current_iteration_state: - # reached loop/iteration end - # get next iteration - next_iteration = current_iteration_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_iteration, NodeRunResult): - if next_iteration.outputs: - for variable_key, variable_value in next_iteration.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=current_iteration_node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - # iteration has ended - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=current_iteration_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - # continue overall process - elif isinstance(next_iteration, str): - # move to next iteration - next_node_id = next_iteration - # get next id - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if not next_node: - break - - # check is already ran - if self._check_node_has_ran(workflow_run_state, next_node.node_id): - predecessor_node = next_node - continue - - has_entry_node = True - - # max steps reached - if workflow_run_state.workflow_node_steps > max_execution_steps: - raise ValueError('Max steps {} reached.'.format(max_execution_steps)) - - # or max execution time reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) - - # handle iteration nodes - if isinstance(next_node, BaseIterationNode): - current_iteration_node = next_node - workflow_run_state.current_iteration_state = next_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - callbacks=callbacks - ) - predecessor_node = next_node - # move to start node of iteration - next_node_id = next_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - continue - else: - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if next_node and next_node.node_id in answer_prov_node_ids: - next_node.is_answer_previous_node = True - - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=next_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) - - if next_node.node_type in [NodeType.END]: - break - - predecessor_node = next_node - - if not has_entry_node: - self._workflow_run_failed( - error='Start node not found in workflow graph.', - callbacks=callbacks - ) - return - except GenerateTaskStoppedException as e: - return - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - return - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - def single_step_run_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: - """ - Single step run workflow node - :param workflow: Workflow instance - :param node_id: node id - :param user_id: user id - :param user_inputs: user inputs - :return: - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - # fetch node config from node id - node_config = None - for node in nodes: - if node.get('id') == node_id: - node_config = node - break - - if not node_config: - raise ValueError('node id not found in workflow graph') - - # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) - - # init workflow run state - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - workflow_call_depth=0 - ) - - try: - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - if node_cls is None: - raise ValueError('Node class not found') - # variable selector to variable mapping - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # run node - node_run_result = node_instance.run( - variable_pool=variable_pool - ) - - # sign output files - node_run_result.outputs = self.handle_special_values(node_run_result.outputs) - except Exception as e: - raise WorkflowNodeRunFailedError( - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_title=node_instance.node_data.title, - error=str(e) - ) - - return node_instance, node_run_result - - def single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[WorkflowCallback], - ) -> None: - """ - Single iteration run workflow node - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - for node in nodes: - if node.get('id') == node_id: - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]: - node_config = node - else: - raise ValueError('node id is not an iteration node') - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - # variable selector to variable mapping - iteration_nested_nodes = [ - node for node in nodes - if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id - ] - iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - for node_config in iteration_nested_nodes: - # mapping user inputs to variable pool - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - if node_cls is None: - raise ValueError('Node class not found') - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - # remove iteration variables - variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() - if value[0] != node_id - } - - # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_nested_node_ids - } - - # append variables to variable pool - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - callbacks=callbacks, - workflow_call_depth=0 - ) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # fetch end node of iteration - end_node_id = None - for edge in graph.get('edges'): - if edge.get('source') == node_id: - end_node_id = edge.get('target') - break - - if not end_node_id: - raise ValueError('end node of iteration not found') - - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - workflow_call_depth=0 - ) - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - start_at=node_id, - end_at=end_node_id - ) - - def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run success - :param callbacks: workflow callbacks - :return: - """ - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_succeeded() - - def _workflow_run_failed(self, error: str, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run failed - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - if callbacks: - for callback in callbacks: - callback.on_workflow_run_failed( - error=error - ) - - def _workflow_iteration_started(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - predecessor_node_id: Optional[str] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration started - :param current_iteration_node: current iteration node - :param workflow_run_state: workflow run state - :param callbacks: workflow callbacks - :return: - """ - # get nested nodes - iteration_nested_nodes = [ - node for node in graph.get('nodes') - if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id - ] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_started( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - node_data=current_iteration_node.node_data, - inputs=workflow_run_state.current_iteration_state.inputs, - predecessor_node_id=predecessor_node_id, - metadata=workflow_run_state.current_iteration_state.metadata.model_dump() - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - def _workflow_iteration_next(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration next - :param workflow_run_state: workflow run state - :return: - """ - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_next( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - index=workflow_run_state.current_iteration_state.index, - node_run_index=workflow_run_state.workflow_node_steps, - output=workflow_run_state.current_iteration_state.get_current_output() - ) - # clear ran nodes - workflow_run_state.workflow_node_runs = [ - node_run for node_run in workflow_run_state.workflow_node_runs - if node_run.iteration_node_id != current_iteration_node.node_id - ] - - # clear variables in current iteration - nodes = graph.get('nodes') - nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] - - for node in nodes: - workflow_run_state.variable_pool.remove((node.get('id'),)) - - def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_completed( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - outputs={ - 'output': workflow_run_state.current_iteration_state.outputs - } - ) - - def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> Optional[BaseNode]: - """ - Get next node - multiple target nodes in the future. - :param graph: workflow graph - :param predecessor_node: predecessor node - :param callbacks: workflow callbacks - :return: - """ - nodes = graph.get('nodes') - if not nodes: - return None - - if not predecessor_node: - for node_config in nodes: - node_cls = None - if start_at: - if node_config.get('id') == start_at: - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - else: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - node_cls = StartNode - if node_cls: - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - else: - edges = graph.get('edges') - source_node_id = predecessor_node.node_id - - # fetch all outgoing edges from source node - outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] - if not outgoing_edges: - return None - - # fetch target node id from outgoing edges - outgoing_edge = None - source_handle = predecessor_node.node_run_result.edge_source_handle \ - if predecessor_node.node_run_result else None - if source_handle: - for edge in outgoing_edges: - if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - outgoing_edge = edge - break - else: - outgoing_edge = outgoing_edges[0] - - if not outgoing_edge: - return None - - target_node_id = outgoing_edge.get('target') - - if end_at and target_node_id == end_at: - return None - - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break - - if not target_node_config: - return None - - # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) - - return target_node( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - node_id: str, - callbacks: Sequence[WorkflowCallback]): - """ - Get node from graph by node id - """ - nodes = graph.get('nodes') - if not nodes: - return None - - for node_config in nodes: - if node_config.get('id') == node_id: - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes[node_type] - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: - """ - Check timeout - :param start_at: start time - :param max_execution_time: max execution time - :return: - """ - return time.perf_counter() - start_at > max_execution_time - - def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: - """ - Check node has ran - """ - return bool([ - node_and_result for node_and_result in workflow_run_state.workflow_node_runs - if node_and_result.node_id == node_id - ]) - - def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - node_run_index=workflow_run_state.workflow_node_steps, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None - ) - - db.session.close() - - workflow_nodes_and_result = WorkflowNodeAndResult( - node=node, - result=None - ) - - # add to workflow_nodes_and_results - workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - # mark node as running - if workflow_run_state.current_iteration_state: - workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( - node_id=node.node_id, - iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id - )) - - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) - except GenerateTaskStoppedException as e: - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error='Workflow stopped.' - ) - except Exception as e: - logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - # node run failed - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_failed( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - error=node_run_result.error, - inputs=node_run_result.inputs, - outputs=node_run_result.outputs, - process_data=node_run_result.process_data, - ) - - raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - - if node.is_answer_previous_node and not isinstance(node, LLMNode): - if not node_run_result.metadata: - node_run_result.metadata = {} - node_run_result.metadata["is_answer_previous_node"]=True - workflow_nodes_and_result.result = node_run_result - - # node run success - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_succeeded( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - inputs=node_run_result.inputs, - process_data=node_run_result.process_data, - outputs=node_run_result.outputs, - execution_metadata=node_run_result.metadata - ) - - if node_run_result.outputs: - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - - if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - db.session.close() - - def _append_variables_recursively(self, variable_pool: VariablePool, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): - """ - Append variables recursively - :param variable_pool: variable pool - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add( - [node_id] + variable_key_list, variable_value - ) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value - ) - - @classmethod - def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None - - new_value = value.copy() - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value - - def _mapping_user_inputs_to_variable_pool(self, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, - variable_pool: VariablePool, - tenant_id: str, - node_instance: BaseNode): - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # FIXME: temp fix for image type - if node_instance.node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.add([variable_node_id]+variable_key_list, value) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py new file mode 100644 index 0000000000..a359bd606e --- /dev/null +++ b/api/core/workflow/workflow_entry.py @@ -0,0 +1,314 @@ +import logging +import time +import uuid +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional, cast + +from configs import dify_config +from core.app.app_config.entities import FileExtraConfig +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunEvent +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.node_mapping import node_classes +from models.workflow import ( + Workflow, + WorkflowType, +) + +logger = logging.getLogger(__name__) + + +class WorkflowEntry: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None + ) -> None: + """ + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph + :param user_id: user id + :param user_from: user from + :param invoke_from: invoke from + :param call_depth: call depth + :param variable_pool: variable pool + :param thread_pool_id: thread pool id + """ + # check call depth + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + + # init workflow run state + self.graph_engine = GraphEngine( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=thread_pool_id + ) + + def run( + self, + *, + callbacks: Sequence[WorkflowCallback], + ) -> Generator[GraphEngineEvent, None, None]: + """ + :param callbacks: workflow callbacks + """ + graph_engine = self.graph_engine + + try: + # run workflow + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event( + event=event + ) + yield event + except GenerateTaskStoppedException: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + if callbacks: + for callback in callbacks: + callback.on_event( + event=GraphRunFailedEvent( + error=str(e) + ) + ) + return + + @classmethod + def single_step_run( + cls, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict + ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get('id') == node_id: + node_config = node + break + + if not node_config: + raise ValueError('node id not found in workflow graph') + + # Get node class + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + if not node_cls: + raise ValueError(f'Node class not found for node type {node_type}') + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + # init graph + graph = Graph.init( + graph_config=workflow.graph_dict + ) + + # init workflow run state + node_instance: BaseNode = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0 + ), + graph=graph, + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter() + ) + ) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=node_instance.node_data + ) + + # run node + generator = node_instance.run() + + return node_instance, generator + except Exception as e: + raise WorkflowNodeRunFailedError( + node_instance=node_instance, + error=str(e) + ) + + @classmethod + def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: + """ + Handle special values + :param value: value + :return: + """ + if not value: + return None + + new_value = dict(value) if value else {} + if isinstance(new_value, dict): + for key, val in new_value.items(): + if isinstance(val, FileVar): + new_value[key] = val.to_dict() + elif isinstance(val, list): + new_val = [] + for v in val: + if isinstance(v, FileVar): + new_val.append(v.to_dict()) + else: + new_val.append(v) + + new_value[key] = new_val + + return new_value + + @classmethod + def mapping_user_inputs_to_variable_pool( + cls, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_type: NodeType, + node_data: BaseNodeData + ) -> None: + for node_variable, variable_selector in variable_mapping.items(): + # fetch node id and variable key from node_variable + node_variable_list = node_variable.split('.') + if len(node_variable_list) < 1: + raise ValueError(f'Invalid node variable {node_variable}') + + node_variable_key = '.'.join(node_variable_list[1:]) + + if ( + node_variable_key not in user_inputs + and node_variable not in user_inputs + ) and not variable_pool.get(variable_selector): + raise ValueError(f'Variable key {node_variable} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + variable_key_list = cast(list[str], variable_key_list) + + # get input value + input_value = user_inputs.get(node_variable) + if not input_value: + input_value = user_inputs.get(node_variable_key) + + # FIXME: temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(input_value, list): + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in input_value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + + # append variable and value to variable pool + variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 0000000000..55824945da --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: 030f4915f36a +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = '030f4915f36a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index cdd5e1992d..e78b5666bc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model): 'triggered_from', 'workflow_run_id'), db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'), + db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_execution_id'), ) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) @@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model): workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) + node_execution_id = db.Column(db.String(255), nullable=True) node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 895855a9c8..73c446b83b 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) -current_dsl_version = "0.1.1" +current_dsl_version = "0.1.2" dsl_to_dify_version_mapping: dict[str, str] = { + "0.1.2": "0.8.0", "0.1.1": "0.6.0", # dsl version -> from dify version } diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 747505977f..26517a05fb 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService @@ -103,9 +104,7 @@ class AppGenerateService: return max_active_requests @classmethod - def generate_single_iteration( - cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True - ): + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( @@ -142,7 +141,7 @@ class AppGenerateService: ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: """ Get workflow :param app_model: app model diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4c3ded14ad..357ffd41c1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account @@ -172,8 +174,13 @@ class WorkflowService: Get default block configs """ # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_configs() + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: """ @@ -182,11 +189,18 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType.value_of(node_type) # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_config(node_type, filters) + node_class = node_classes.get(node_type_enum) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config def run_draft_workflow_node( self, app_model: App, node_id: str, user_inputs: dict, account: Account @@ -200,82 +214,68 @@ class WorkflowService: raise ValueError("Workflow not initialized") # run draft workflow node - workflow_engine_manager = WorkflowEngineManager() start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + node_instance, generator = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ) + + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + + run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=e.node_id, - node_type=e.node_type.value, - title=e.node_title, - status=WorkflowNodeExecutionStatus.FAILED.value, - error=e.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) - db.session.add(workflow_node_execution) - db.session.commit() + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error - return workflow_node_execution + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = app_model.tenant_id + workflow_node_execution.app_id = app_model.id + workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type.value + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by = account.id + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), + workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None + workflow_node_execution.process_data = ( + json.dumps(node_run_result.process_data) if node_run_result.process_data else None ) + workflow_node_execution.outputs = ( + json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None + ) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value else: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - status=node_run_result.status.value, - error=node_run_result.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error db.session.add(workflow_node_execution) db.session.commit() @@ -321,25 +321,3 @@ class WorkflowService: ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") - - @classmethod - def get_elapsed_time(cls, workflow_run_id: str) -> float: - """ - Get elapsed time - """ - elapsed_time = 0.0 - - # fetch workflow node execution by workflow_run_id - workflow_nodes = ( - db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) - .order_by(WorkflowNodeExecution.created_at.asc()) - .all() - ) - if not workflow_nodes: - return elapsed_time - - for node in workflow_nodes: - elapsed_time += node.elapsed_time - - return elapsed_time diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 6f5421e108..952c90674d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,17 +1,72 @@ +import time +import uuid from os import getenv +from typing import cast import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.nodes.code.entities import CodeNodeData +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) +def init_code_node(code_config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["code", "123", "args1"], 1) + variable_pool.add(["code", "123", "args2"], 2) + + node = CodeNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=code_config, + ) + + return node + + @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """ @@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, - }, - }, - ) - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["result"] == 3 assert result.error is None @@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "string", - }, - }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, - }, - }, - ) - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "string", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) # execute node - result = node.run(pool) - + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error == "Output variable `result` must be a string" @@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth(): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "string_validator": { - "type": "string", - }, - "number_validator": { - "type": "number", - }, - "number_array_validator": { - "type": "array[number]", - }, - "string_array_validator": { - "type": "array[string]", - }, - "object_validator": { - "type": "object", - "children": { - "result": { - "type": "number", - }, - "depth": { - "type": "object", - "children": { - "depth": { - "type": "object", - "children": { - "depth": { - "type": "number", - } - }, - } - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + }, + } }, }, }, }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + node = init_code_node(code_config) # construct result result = { @@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) @@ -250,35 +284,30 @@ def test_execute_code_output_object_list(): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, - config={ - "id": "1", - "data": { - "outputs": { - "object_list": { - "type": "array[object]", - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "object_list": { + "type": "array[object]", }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + node = init_code_node(code_config) # construct result result = { @@ -295,6 +324,8 @@ def test_execute_code_output_object_list(): ] } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index acb616b325..65aaa0bddd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,31 +1,69 @@ +import time +import uuid from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock -BASIC_NODE_DATA = { - "tenant_id": "1", - "app_id": "1", - "workflow_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.WEB_APP, -} -# construct variable pool -pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(["a", "b123", "args1"], 1) -pool.add(["a", "b123", "args2"], 2) +def init_http_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return HttpRequestNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -45,12 +83,11 @@ def test_get(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -59,7 +96,7 @@ def test_get(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_no_auth(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -143,11 +178,11 @@ def test_template(setup_http_mock): "params": "A:b\nTemplate:{{#a.b123.args2#}}", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -158,7 +193,7 @@ def test_template(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -178,11 +213,11 @@ def test_json(setup_http_mock): "params": "A:b", "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert '{"a": "1"}' in data @@ -190,7 +225,7 @@ def test_json(setup_http_mock): def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock): "params": "A:b", "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "a=1&b=2" in data @@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock): def test_form_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -242,11 +277,11 @@ def test_form_data(setup_http_mock): "params": "A:b", "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert 'form-data; name="a"' in data @@ -257,7 +292,7 @@ def test_form_data(setup_http_mock): def test_none_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -277,11 +312,11 @@ def test_none_data(setup_http_mock): "params": "A:b", "body": {"type": "none", "data": "123123123"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "X-Header: 123" in data @@ -289,7 +324,7 @@ def test_none_data(setup_http_mock): def test_mock_404(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock): "params": "", "headers": "X-Header:123", }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.outputs is not None resp = result.outputs assert 404 == resp.get("status_code") - assert "Not Found" in resp.get("body") + assert "Not Found" in resp.get("body", "") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock): "headers": "Referer:http://example3.com\nRedirect:http://example4.com", "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None + assert result.outputs is not None resp = result.outputs - assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") - assert "http://example3.com" == resp.get("headers").get("referer") + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "") + assert "http://example3.com" == resp.get("headers", {}).get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6bab83a019..dfb43650d2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,5 +1,8 @@ import json import os +import time +import uuid +from collections.abc import Generator from unittest.mock import MagicMock import pytest @@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db from models.provider import ProviderType -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_execute_llm(setup_openai_mock): - node = LLMNode( +def init_llm_node(config: dict) -> LLMNode: + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["abc", "output"], "sunny") + + node = LLMNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + return node + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_execute_llm(setup_openai_mock): + node = init_llm_node( config={ "id": "llm", "data": { @@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock): model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, Generator) - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["text"] is not None - assert result.outputs["usage"]["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ - node = LLMNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_llm_node( config={ "id": "llm", "data": { @@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") - + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) + result = node._run() - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "sunny" in json.dumps(result.process_data) - assert "what's the weather today?" in json.dumps(result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ca2bae5c53..cbe9c5914f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,5 +1,7 @@ import json import os +import time +import uuid from typing import Optional from unittest.mock import MagicMock @@ -8,19 +10,21 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db from models.provider import ProviderType """FOR MOCK FIXTURES, DO NOT REMOVE""" -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock @@ -47,13 +51,15 @@ def get_mocked_fetch_model_config( model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) + model_schema = model_type_instance.get_model_schema(model) + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model=model, provider=provider, mode=mode, credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema(model), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str): return MagicMock(return_value=MemoryMock()) +def init_parameter_extractor_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return ParameterExtractorNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock): "reasoning_mode": "function_call", "memory": None, }, - }, + } ) node._fetch_model_config = get_mocked_fetch_model_config( @@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): environment_variables=[], ) - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None @@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None process_data = result.process_data + assert process_data is not None process_data.get("prompts") - for prompt in process_data.get("prompts"): + for prompt in process_data.get("prompts", []): if prompt.get("role") == "system": assert "what's the weather in SF" in prompt.get("text") @@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) for prompt in prompts: if prompt.get("role") == "user": @@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - assert len(result.process_data.get("prompts")) == 1 - assert "SF" in result.process_data.get("prompts")[0].get("text") + assert result.process_data is not None + assert len(result.process_data.get("prompts", [])) == 1 + assert "SF" in result.process_data.get("prompts", [])[0].get("text") def test_extract_json_response(): @@ -322,13 +325,7 @@ def test_extract_json_response(): Test extract json response. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -357,6 +354,7 @@ def test_extract_json_response(): hello world. """) + assert result is not None assert result["location"] == "kawaii" @@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) latest_role = None for prompt in prompts: diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 617b6370c9..073c4bb799 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,46 +1,84 @@ +import time +import uuid + import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """{{args2}}""" - node = TemplateTransformNode( + config = { + "id": "1", + "data": { + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "template": code, + }, + } + + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.END_USER, - config={ - "id": "1", - "data": { - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "template": code, - }, - }, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 3) + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["1", "123", "args1"], 1) + variable_pool.add(["1", "123", "args2"], 3) + + node = TemplateTransformNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) # execute node - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 29c1efa8e7..4d94cdb28a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,21 +1,62 @@ +import time +import uuid + from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.tool.tool_node import ToolNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def init_tool_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + return ToolNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) def test_tool_variable_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], "1+1") - - node = ToolNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_tool_node( config={ "id": "1", "data": { @@ -34,28 +75,22 @@ def test_tool_variable_invoke(): } }, }, - }, + } ) - # execute node - result = node.run(pool) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1") + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert "2" in result.outputs["text"] assert result.outputs["files"] == [] def test_tool_mixed_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "args1"], "1+1") - - node = ToolNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_tool_node( config={ "id": "1", "data": { @@ -74,12 +109,15 @@ def test_tool_mixed_invoke(): } }, }, - }, + } ) - # execute node - result = node.run(pool) + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert "2" in result.outputs["text"] assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index afc9802cf1..ca3082953a 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -1,7 +1,24 @@ import os +import pytest +from flask import Flask + # Getting the absolute path of the current file's directory ABS_PATH = os.path.dirname(os.path.abspath(__file__)) # Getting the absolute path of the project's root directory PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) + +CACHED_APP = Flask(__name__) +CACHED_APP.config.update({"TESTING": True}) + + +@pytest.fixture() +def app() -> Flask: + return CACHED_APP + + +@pytest.fixture(autouse=True) +def _provide_app_context(app: Flask): + with app.app_context(): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py new file mode 100644 index 0000000000..65757cd604 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -0,0 +1,791 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.utils.condition.entities import Condition + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "start-source-qc-target", + "source": "start", + "target": "qc", + }, + { + "id": "qc-1-llm-target", + "source": "qc", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "qc-2-http-target", + "source": "qc", + "sourceHandle": "2", + "target": "http", + }, + { + "id": "http-source-answer2-target", + "source": "http", + "target": "answer2", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "question-classifier"}, + "id": "qc", + }, + { + "data": { + "type": "http-request", + }, + "id": "http", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + start_node_id = "start" + + assert graph.root_node_id == start_node_id + assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" + assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} + + +def test__init_iteration_graph(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "iteration"}, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + }, + ], + } + + graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration") + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")], + ), + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" + + +def test_parallels_graph(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + start_edges = graph.edge_mapping.get("start") + assert start_edges is not None + assert start_edges[i].target_node_id == f"llm{i+1}" + + llm_edges = graph.edge_mapping.get(f"llm{i+1}") + assert llm_edges is not None + assert llm_edges[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph2(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + if i < 2: + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph3(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph4(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "code2", + }, + { + "id": "llm3-source-code3-target", + "source": "llm3", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" + assert graph.edge_mapping.get(f"code{i + 1}") is not None + assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph5(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm4", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm5", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-code1-target", + "source": "llm2", + "target": "code1", + }, + { + "id": "llm3-source-code2-target", + "source": "llm3", + "target": "code2", + }, + { + "id": "llm4-source-code2-target", + "source": "llm4", + "target": "code2", + }, + { + "id": "llm5-source-code3-target", + "source": "llm5", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(5): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm3") is not None + assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm4") is not None + assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm5") is not None + assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 8 + + for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph6(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm1-source-code2-target", + "source": "llm1", + "target": "code2", + }, + { + "id": "llm2-source-code3-target", + "source": "llm2", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code3") is not None + assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 2 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + parent_parallel = None + child_parallel = None + for p_id, parallel in graph.parallel_mapping.items(): + if parallel.parent_parallel_id is None: + parent_parallel = parallel + else: + child_parallel = parallel + + for node_id in ["llm1", "llm2", "llm3", "code3"]: + assert graph.node_parallel_mapping[node_id] == parent_parallel.id + + for node_id in ["code1", "code2"]: + assert graph.node_parallel_mapping[node_id] == child_parallel.id diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py new file mode 100644 index 0000000000..a2d71d61fc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -0,0 +1,505 @@ +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + BaseNodeEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.llm_node import LLMNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_workflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "llm1", + }, + { + "id": "2", + "source": "llm1", + "target": "llm2", + }, + { + "id": "3", + "source": "llm1", + "target": "llm3", + }, + { + "id": "4", + "source": "llm2", + "target": "end1", + }, + { + "id": "5", + "source": "llm3", + "target": "end2", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + "title": "start", + "variables": [ + { + "label": "query", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "query", + } + ], + }, + "id": "start", + }, + { + "data": { + "type": "llm", + "title": "llm1", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say hi"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + "title": "llm2", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say bye"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + "title": "llm3", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say good morning"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm3", + }, + { + "data": { + "type": "end", + "title": "end1", + "outputs": [ + {"value_selector": ["llm2", "text"], "variable": "result2"}, + {"value_selector": ["start", "query"], "variable": "query"}, + ], + }, + "id": "end1", + }, + { + "data": { + "type": "end", + "title": "end2", + "outputs": [ + {"value_selector": ["llm1", "text"], "variable": "result1"}, + {"value_selector": ["llm3", "text"], "variable": "result3"}, + ], + }, + "id": "end2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent( + chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"] + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + # print("") + + with patch.object(LLMNode, "_run", new=llm_generator): + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + assert item.parallel_id is not None + + assert len(items) == 18 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_chatflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "answer1", + }, + { + "id": "2", + "source": "answer1", + "target": "answer2", + }, + { + "id": "3", + "source": "answer1", + "target": "answer3", + }, + { + "id": "4", + "source": "answer2", + "target": "answer4", + }, + { + "id": "5", + "source": "answer3", + "target": "answer5", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "start"}, "id": "start"}, + {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"}, + { + "data": {"type": "answer", "title": "answer2", "answer": "2"}, + "id": "answer2", + }, + { + "data": {"type": "answer", "title": "answer3", "answer": "3"}, + "id": "answer3", + }, + { + "data": {"type": "answer", "title": "answer4", "answer": "4"}, + "id": "answer4", + }, + { + "data": {"type": "answer", "title": "answer5", "answer": "5"}, + "id": "answer5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + "answer2", + "answer3", + "answer4", + "answer5", + ]: + assert item.parallel_id is not None + + assert len(items) == 23 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_branch(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "if-else-1", + }, + { + "id": "2", + "source": "if-else-1", + "sourceHandle": "true", + "target": "answer-1", + }, + { + "id": "3", + "source": "if-else-1", + "sourceHandle": "false", + "target": "if-else-2", + }, + { + "id": "4", + "source": "if-else-2", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "5", + "source": "if-else-2", + "sourceHandle": "false", + "target": "answer-3", + }, + ], + "nodes": [ + { + "data": { + "title": "Start", + "type": "start", + "variables": [ + { + "label": "uid", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "uid", + } + ], + }, + "id": "start", + }, + { + "data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []}, + "id": "answer-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", + "value": "hi", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "desc": "", + "title": "IF/ELSE", + "type": "if-else", + }, + "id": "if-else-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "ae895199-5608-433b-b5f0-0997ae1431e4", + "value": "takatost", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "title": "IF/ELSE 2", + "type": "if-else", + }, + "id": "if-else-2", + }, + { + "data": { + "answer": "2", + "title": "Answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "answer": "3", + "title": "Answer 3", + "type": "answer", + }, + "id": "answer-3", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "hi", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + + assert len(items) == 10 + assert items[3].route_node_state.node_id == "if-else-1" + assert items[4].route_node_state.node_id == "if-else-1" + assert isinstance(items[5], NodeRunStreamChunkEvent) + assert items[5].chunk_content == "1 " + assert isinstance(items[6], NodeRunStreamChunkEvent) + assert items[6].chunk_content == "takato" + assert items[7].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato" + assert isinstance(items[9], GraphRunSucceededEvent) + + # print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py b/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py new file mode 100644 index 0000000000..fe4ede6335 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -0,0 +1,82 @@ +import time +import uuid +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.answer.answer_node import AnswerNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_execute_answer(): + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py new file mode 100644 index 0000000000..bce87536d8 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py @@ -0,0 +1,109 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + answer_stream_generate_route = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping + ) + + assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"] + assert answer_stream_generate_route.answer_dependencies["answer2"] == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py new file mode 100644 index 0000000000..6b1d1e9070 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -0,0 +1,216 @@ +import uuid +from collections.abc import Generator +from datetime import datetime, timezone + +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.start.entities import StartNodeData + + +def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + if next_node_id == "start": + yield from _publish_events(graph, next_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _publish_events(graph, edge.target_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _recursive_process(graph, edge.target_node_id) + + +def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + + parallel_id = graph.node_parallel_mapping.get(next_node_id) + parallel_start_node_id = None + if parallel_id: + parallel = graph.parallel_mapping.get(parallel_id) + parallel_start_node_id = parallel.start_from_node_id if parallel else None + + node_execution_id = str(uuid.uuid4()) + node_config = graph.node_id_config_mapping[next_node_id] + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + mock_node_data = StartNodeData(**{"title": "demo", "variables": []}) + + yield NodeRunStartedEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_start_node_id=parallel_start_node_id, + ) + + if "llm" in next_node_id: + length = int(next_node_id[-1]) + for i in range(0, length): + yield NodeRunStreamChunkEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + chunk_content=str(i), + route_node_state=route_node_state, + from_variable_selector=[next_node_id, "text"], + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + route_node_state.status = RouteNodeState.Status.SUCCESS + route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + yield NodeRunSucceededEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + +def test_process(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool) + + def graph_generator() -> Generator[GraphEngineEvent, None, None]: + # print("") + for event in _recursive_process(graph, "start"): + # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunSucceededEvent): + if "llm" in event.route_node_state.node_id: + variable_pool.add( + [event.route_node_state.node_id, "text"], + "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))), + ) + yield event + + result_generator = answer_stream_processor.process(graph_generator()) + stream_contents = "" + for event in result_generator: + # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunStreamChunkEvent): + stream_contents += event.chunk_content + pass + + assert stream_contents == "c012da01b" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py new file mode 100644 index 0000000000..b3a89061b2 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -0,0 +1,420 @@ +import time +import uuid +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_run(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 20 + + +def test_run_parallel(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 32 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 8020674ee6..cb2e99a854 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,22 +1,70 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_answer(): - node = AnswerNode( + graph_config = { + "edges": [ + { + "id": "start-source-answer-target", + "source": "start", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["start", "weather"], "sunny") + variable_pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "answer", "data": { @@ -27,20 +75,11 @@ def test_execute_answer(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "weather"], "sunny") - pool.add(["llm", "text"], "You are a helpful AI.") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 9535bc2186..0795f134d0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,22 +1,63 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_if_else_result_true(): - node = IfElseNode( + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -63,48 +104,64 @@ def test_execute_if_else_result_true(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["ab", "def"]) - pool.add(["start", "array_not_contains"], ["ac", "def"]) - pool.add(["start", "contains"], "cabcde") - pool.add(["start", "not_contains"], "zacde") - pool.add(["start", "start_with"], "abc") - pool.add(["start", "end_with"], "zzab") - pool.add(["start", "is"], "ab") - pool.add(["start", "is_not"], "aab") - pool.add(["start", "empty"], "") - pool.add(["start", "not_empty"], "aaa") - pool.add(["start", "equals"], 22) - pool.add(["start", "not_equals"], 23) - pool.add(["start", "greater_than"], 23) - pool.add(["start", "less_than"], 21) - pool.add(["start", "greater_than_or_equal"], 22) - pool.add(["start", "less_than_or_equal"], 21) - pool.add(["start", "not_null"], "1212") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is True def test_execute_if_else_result_false(): - node = IfElseNode( + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -127,20 +184,11 @@ def test_execute_if_else_result_false(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["1ab", "def"]) - pool.add(["start", "array_not_contains"], ["ab", "def"]) - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index e26c7df642..f45a93f1be 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -1,17 +1,56 @@ +import time +import uuid from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode +from models.workflow import WorkflowType DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = StringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -24,13 +63,24 @@ def test_overwrite_string_variable(): value="the second value", ) + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "node_id", "data": { @@ -41,19 +91,8 @@ def test_overwrite_string_variable(): }, ) - variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -63,6 +102,39 @@ def test_overwrite_string_variable(): def test_append_variable_to_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -75,23 +147,6 @@ def test_append_variable_to_array(): value="the second value", ) - node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config={ - "id": "node_id", - "data": { - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, - ) - variable_pool = VariablePool( system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, @@ -103,8 +158,23 @@ def test_append_variable_to_array(): input_variable, ) + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -113,19 +183,57 @@ def test_append_variable_to_array(): def test_clear_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", value=["the first value"], ) + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "node_id", "data": { @@ -136,14 +244,9 @@ def test_clear_array(): }, ) - variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - node.run(variable_pool) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index bc42d4b8f7..7075a31f2b 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -396,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.3 + image: langgenius/dify-web:0.8.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aca87a6f8f..68f897ddb9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -208,7 +208,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -228,7 +228,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.3 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -247,7 +247,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.3 + image: langgenius/dify-web:0.8.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 5f36e40c40..1f17798f83 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -11,10 +11,10 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import type { ChatItem, WorkflowProcess } from '../../types' +import TracingPanel from '@/app/components/workflow/run/tracing-panel' import cn from '@/utils/classnames' import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general' import { WorkflowRunningStatus } from '@/app/components/workflow/types' -import NodePanel from '@/app/components/workflow/run/node' import { useStore as useAppStore } from '@/app/components/app/store' type WorkflowProcessProps = { @@ -107,16 +107,12 @@ const WorkflowProcessItem = ({ !collapse && (
{ - data.tracing.map(node => ( -
- -
- )) + }
) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index e6638b8eed..892f88c4ad 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -241,8 +241,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) hasStopResponded.current = false @@ -503,12 +501,13 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, onIterationFinish: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], + const iterationIndex = tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + tracing[iterationIndex] = { + ...tracing[iterationIndex], ...data, status: WorkflowRunningStatus.Succeeded, } as any @@ -520,10 +519,9 @@ export const useChat = ( ...responseItem, } })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return responseItem.workflowProcess!.tracing!.push({ @@ -539,10 +537,15 @@ export const useChat = ( })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return - const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id) + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata.parallel_id) + }) responseItem.workflowProcess!.tracing[currentIndex] = data as any handleUpdateChatList(produce(chatListRef.current, (draft) => { const currentIndex = draft.findIndex(item => item.id === responseItem.id) diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 021a98d4d3..96fe9f01ef 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -196,8 +196,6 @@ const Result: FC = ({ })() if (isWorkflow) { - let isInIteration = false - sendWorkflowMessage( data, { @@ -219,23 +217,28 @@ const Result: FC = ({ expand: true, } as any) })) - isInIteration = true }, onIterationNext: () => { + setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + draft.expand = true + const iterations = draft.tracing.find(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + iterations?.details!.push([]) + })) }, onIterationFinish: ({ data }) => { setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true - // const iteration = draft.tracing![draft.tracing!.length - 1] - draft.tracing![draft.tracing!.length - 1] = { + const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + draft.tracing[iterationsIndex] = { ...data, expand: !!data.error, } as any })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { @@ -248,11 +251,12 @@ const Result: FC = ({ })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) + const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id + && (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id)) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras diff --git a/web/app/components/workflow/candidate-node.tsx b/web/app/components/workflow/candidate-node.tsx index faae545b3b..16d6f852b2 100644 --- a/web/app/components/workflow/candidate-node.tsx +++ b/web/app/components/workflow/candidate-node.tsx @@ -14,9 +14,11 @@ import { } from './store' import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks' import { CUSTOM_NODE } from './constants' +import { getIterationStartNode } from './utils' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' +import { BlockEnum } from './types' const CandidateNode = () => { const store = useStoreApi() @@ -52,6 +54,8 @@ const CandidateNode = () => { y, }, }) + if (candidateNode.data.type === BlockEnum.Iteration) + draft.push(getIterationStartNode(candidateNode.id)) }) setNodes(newNodes) if (candidateNode.type === CUSTOM_NOTE_NODE) diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 070748bab0..6a4629e9c8 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -15,6 +15,7 @@ import VariableAssignerDefault from './nodes/variable-assigner/default' import AssignerDefault from './nodes/assigner/default' import EndNodeDefault from './nodes/end/default' import IterationDefault from './nodes/iteration/default' +import IterationStartDefault from './nodes/iteration-start/default' type NodesExtraData = { author: string @@ -89,6 +90,15 @@ export const NODES_EXTRA_DATA: Record = { getAvailableNextNodes: IterationDefault.getAvailableNextNodes, checkValid: IterationDefault.checkValid, }, + [BlockEnum.IterationStart]: { + author: 'Dify', + about: '', + availablePrevNodes: [], + availableNextNodes: [], + getAvailablePrevNodes: IterationStartDefault.getAvailablePrevNodes, + getAvailableNextNodes: IterationStartDefault.getAvailableNextNodes, + checkValid: IterationStartDefault.checkValid, + }, [BlockEnum.Code]: { author: 'Dify', about: '', @@ -222,6 +232,12 @@ export const NODES_INITIAL_DATA = { desc: '', ...IterationDefault.defaultValue, }, + [BlockEnum.IterationStart]: { + type: BlockEnum.IterationStart, + title: '', + desc: '', + ...IterationStartDefault.defaultValue, + }, [BlockEnum.Code]: { type: BlockEnum.Code, title: '', @@ -305,11 +321,13 @@ export const AUTO_LAYOUT_OFFSET = { export const ITERATION_NODE_Z_INDEX = 1 export const ITERATION_CHILDREN_Z_INDEX = 1002 export const ITERATION_PADDING = { - top: 85, + top: 65, right: 16, bottom: 20, left: 16, } +export const PARALLEL_LIMIT = 10 +export const PARALLEL_DEPTH_LIMIT = 3 export const RETRIEVAL_OUTPUT_STRUCT = `{ "content": "", @@ -412,4 +430,5 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' +export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 3645e18449..af2a1500ba 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -16,6 +16,7 @@ import { useReactFlow, useStoreApi, } from 'reactflow' +import { unionBy } from 'lodash-es' import type { ToolDefaultValue } from '../block-selector/types' import type { Edge, @@ -25,6 +26,7 @@ import type { import { BlockEnum } from '../types' import { useWorkflowStore } from '../store' import { + CUSTOM_EDGE, ITERATION_CHILDREN_Z_INDEX, ITERATION_PADDING, NODES_INITIAL_DATA, @@ -40,6 +42,7 @@ import { } from '../utils' import { CUSTOM_NOTE_NODE } from '../note-node/constants' import type { IterationNodeType } from '../nodes/iteration/types' +import { CUSTOM_ITERATION_START_NODE } from '../nodes/iteration-start/constants' import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types' import { useNodeIterationInteractions } from '../nodes/iteration/use-interactions' import { useWorkflowHistoryStore } from '../workflow-history-store' @@ -60,6 +63,7 @@ export const useNodesInteractions = () => { const { store: workflowHistoryStore } = useWorkflowHistoryStore() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { + checkNestedParallelLimit, getAfterNodesInSameBranch, } = useWorkflow() const { getNodesReadOnly } = useNodesReadOnly() @@ -79,7 +83,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart || node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_NOTE_NODE) return dragNodeStartPosition.current = { x: node.position.x, y: node.position.y } @@ -89,7 +93,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart) + if (node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -156,7 +160,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -207,13 +211,30 @@ export const useNodesInteractions = () => { }) }) setEdges(newEdges) + const connectedEdges = getConnectedEdges([node], edges).filter(edge => edge.target === node.id) + + const targetNodes: Node[] = [] + for (let i = 0; i < connectedEdges.length; i++) { + const sourceConnectedEdges = getConnectedEdges([{ id: connectedEdges[i].source } as Node], edges).filter(edge => edge.source === connectedEdges[i].source && edge.sourceHandle === connectedEdges[i].sourceHandle) + targetNodes.push(...sourceConnectedEdges.map(edge => nodes.find(n => n.id === edge.target)!)) + } + const uniqTargetNodes = unionBy(targetNodes, 'id') + if (uniqTargetNodes.length > 1) { + const newNodes = produce(nodes, (draft) => { + draft.forEach((n) => { + if (uniqTargetNodes.some(targetNode => n.id === targetNode.id)) + n.data._inParallelHovering = true + }) + }) + setNodes(newNodes) + } }, [store, workflowStore, getNodesReadOnly]) const handleNodeLeave = useCallback((_, node) => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -229,6 +250,7 @@ export const useNodesInteractions = () => { const newNodes = produce(getNodes(), (draft) => { draft.forEach((node) => { node.data._isEntering = false + node.data._inParallelHovering = false }) }) setNodes(newNodes) @@ -287,6 +309,8 @@ export const useNodesInteractions = () => { }, [store, handleSyncWorkflowDraft]) const handleNodeClick = useCallback((_, node) => { + if (node.type === CUSTOM_ITERATION_START_NODE) + return handleNodeSelect(node.id) }, [handleNodeSelect]) @@ -314,25 +338,15 @@ export const useNodesInteractions = () => { if (targetNode?.parentId !== sourceNode?.parentId) return - if (targetNode?.data.isIterationStart) - return - if (sourceNode?.type === CUSTOM_NOTE_NODE || targetNode?.type === CUSTOM_NOTE_NODE) return - const needDeleteEdges = edges.filter((edge) => { - if ( - (edge.source === source && edge.sourceHandle === sourceHandle) - || (edge.target === target && edge.targetHandle === targetHandle && targetNode?.data.type !== BlockEnum.VariableAssigner && targetNode?.data.type !== BlockEnum.VariableAggregator) - ) - return true + if (edges.find(edge => edge.source === source && edge.sourceHandle === sourceHandle && edge.target === target && edge.targetHandle === targetHandle)) + return - return false - }) - const needDeleteEdgesIds = needDeleteEdges.map(edge => edge.id) const newEdge = { id: `${source}-${sourceHandle}-${target}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: source!, target: target!, sourceHandle, @@ -347,7 +361,6 @@ export const useNodesInteractions = () => { } const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( [ - ...needDeleteEdges.map(edge => ({ type: 'remove', edge })), { type: 'add', edge: newEdge }, ], nodes, @@ -362,19 +375,26 @@ export const useNodesInteractions = () => { } }) }) - setNodes(newNodes) const newEdges = produce(edges, (draft) => { - const filtered = draft.filter(edge => !needDeleteEdgesIds.includes(edge.id)) - - filtered.push(newEdge) - - return filtered + draft.push(newEdge) }) - setEdges(newEdges) - handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) - }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + } + else { + const { + setConnectingNodePayload, + setEnteringNodePayload, + } = workflowStore.getState() + setConnectingNodePayload(undefined) + setEnteringNodePayload(undefined) + } + }, [getNodesReadOnly, store, workflowStore, handleSyncWorkflowDraft, saveStateToHistory, checkNestedParallelLimit]) const handleNodeConnectStart = useCallback((_, { nodeId, handleType, handleId }) => { if (getNodesReadOnly()) @@ -393,14 +413,12 @@ export const useNodesInteractions = () => { return } - if (!node.data.isIterationStart) { - setConnectingNodePayload({ - nodeId, - nodeType: node.data.type, - handleType, - handleId, - }) - } + setConnectingNodePayload({ + nodeId, + nodeType: node.data.type, + handleType, + handleId, + }) } }, [store, workflowStore, getNodesReadOnly]) @@ -510,6 +528,12 @@ export const useNodesInteractions = () => { return handleNodeDelete(nodeId) } else { + if (iterationChildren.length === 1) { + handleNodeDelete(iterationChildren[0].id) + handleNodeDelete(nodeId) + + return + } const { setShowConfirm, showConfirm } = workflowStore.getState() if (!showConfirm) { @@ -541,14 +565,8 @@ export const useNodesInteractions = () => { } } - if (node.id === currentNode.parentId) { + if (node.id === currentNode.parentId) node.data._children = node.data._children?.filter(child => child !== nodeId) - - if (currentNode.id === (node as Node).data.start_node_id) { - (node as Node).data.start_node_id = ''; - (node as Node).data.startNodeType = undefined - } - } }) draft.splice(currentNodeIndex, 1) }) @@ -559,7 +577,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - if (currentNode.type === 'custom-note') + if (currentNode.type === CUSTOM_NOTE_NODE) saveStateToHistory(WorkflowHistoryEvent.NoteDelete) else @@ -591,7 +609,10 @@ export const useNodesInteractions = () => { } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -627,7 +648,7 @@ export const useNodesInteractions = () => { const newEdge: Edge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -662,8 +683,10 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) + if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { const { setShowAssignVariablePopup } = workflowStore.getState() @@ -687,7 +710,14 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } } if (!prevNodeId && nextNodeId) { const nextNodeIndex = nodes.findIndex(node => node.id === nextNodeId) @@ -706,15 +736,13 @@ export const useNodesInteractions = () => { newNode.data.iteration_id = nextNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX } - if (nextNode.data.isIterationStart) - newNode.data.isIterationStart = true let newEdge if ((nodeType !== BlockEnum.IfElse) && (nodeType !== BlockEnum.QuestionClassifier)) { newEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -763,13 +791,11 @@ export const useNodesInteractions = () => { node.data.start_node_id = newNode.id node.data.startNodeType = newNode.data.type } - - if (node.id === nextNodeId && node.data.isIterationStart) - node.data.isIterationStart = false }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) if (newEdge) { const newEdges = produce(edges, (draft) => { draft.forEach((item) => { @@ -780,7 +806,21 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } + } + else { + if (checkNestedParallelLimit(newNodes, edges)) + setNodes(newNodes) + + else + return false } } if (prevNodeId && nextNodeId) { @@ -804,7 +844,7 @@ export const useNodesInteractions = () => { const currentEdgeIndex = edges.findIndex(edge => edge.source === prevNodeId && edge.target === nextNodeId) const newPrevEdge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -822,7 +862,7 @@ export const useNodesInteractions = () => { if (nodeType !== BlockEnum.IfElse && nodeType !== BlockEnum.QuestionClassifier) { newNextEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -865,6 +905,8 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { @@ -898,7 +940,7 @@ export const useNodesInteractions = () => { } handleSyncWorkflowDraft() saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch]) + }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit]) const handleNodeChange = useCallback(( currentNodeId: string, @@ -919,7 +961,10 @@ export const useNodesInteractions = () => { const currentNode = nodes.find(node => node.id === currentNodeId)! const connectedEdges = getConnectedEdges([currentNode], edges) const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newCurrentNode = generateNewNode({ + const { + newNode: newCurrentNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -929,7 +974,6 @@ export const useNodesInteractions = () => { selected: currentNode.data.selected, isInIteration: currentNode.data.isInIteration, iteration_id: currentNode.data.iteration_id, - isIterationStart: currentNode.data.isIterationStart, }, position: { x: currentNode.position.x, @@ -955,18 +999,12 @@ export const useNodesInteractions = () => { ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], } } - if (node.id === currentNode.parentId && currentNode.data.isIterationStart) { - node.data._children = [ - newCurrentNode.id, - ...(node.data._children || []), - ].filter(child => child !== currentNodeId) - node.data.start_node_id = newCurrentNode.id - node.data.startNodeType = newCurrentNode.data.type - } }) const index = draft.findIndex(node => node.id === currentNodeId) draft.splice(index, 1, newCurrentNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) const newEdges = produce(edges, (draft) => { @@ -1011,7 +1049,7 @@ export const useNodesInteractions = () => { }, [store]) const handleNodeContextMenu = useCallback((e: MouseEvent, node: Node) => { - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return e.preventDefault() @@ -1041,7 +1079,7 @@ export const useNodesInteractions = () => { if (nodeId) { // If nodeId is provided, copy that specific node - const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start) + const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start && node.type !== CUSTOM_ITERATION_START_NODE) if (nodeToCopy) setClipboardElements([nodeToCopy]) } @@ -1087,7 +1125,10 @@ export const useNodesInteractions = () => { clipboardElements.forEach((nodeToPaste, index) => { const nodeType = nodeToPaste.data.type - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ type: nodeToPaste.type, data: { ...NODES_INITIAL_DATA[nodeType], @@ -1106,24 +1147,17 @@ export const useNodesInteractions = () => { zIndex: nodeToPaste.zIndex, }) newNode.id = newNode.id + index - - // If only the iteration start node is copied, remove the isIterationStart flag // This new node is movable and can be placed anywhere - if (clipboardElements.length === 1 && newNode.data.isIterationStart) - newNode.data.isIterationStart = false - let newChildren: Node[] = [] if (nodeToPaste.data.type === BlockEnum.Iteration) { - newNode.data._children = []; - (newNode.data as IterationNodeType).start_node_id = '' + newIterationStartNode!.parentId = newNode.id; + (newNode.data as IterationNodeType).start_node_id = newIterationStartNode!.id newChildren = handleNodeIterationChildrenCopy(nodeToPaste.id, newNode.id) - newChildren.forEach((child) => { newNode.data._children?.push(child.id) - if (child.data.isIterationStart) - (newNode.data as IterationNodeType).start_node_id = child.id }) + newChildren.push(newIterationStartNode!) } nodesToPaste.push(newNode) @@ -1230,6 +1264,42 @@ export const useNodesInteractions = () => { saveStateToHistory(WorkflowHistoryEvent.NodeResize) }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + const handleNodeDisconnect = useCallback((nodeId: string) => { + if (getNodesReadOnly()) + return + + const { + getNodes, + setNodes, + edges, + setEdges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId)! + const connectedEdges = getConnectedEdges([currentNode], edges) + const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( + connectedEdges.map(edge => ({ type: 'remove', edge })), + nodes, + ) + const newNodes = produce(nodes, (draft: Node[]) => { + draft.forEach((node) => { + if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) { + node.data = { + ...node.data, + ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], + } + } + }) + }) + setNodes(newNodes) + const newEdges = produce(edges, (draft) => { + return draft.filter(edge => !connectedEdges.find(connectedEdge => connectedEdge.id === edge.id)) + }) + setEdges(newEdges) + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.EdgeDelete) + }, [store, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory]) + const handleHistoryBack = useCallback(() => { if (getNodesReadOnly() || getWorkflowReadOnly()) return @@ -1282,6 +1352,7 @@ export const useNodesInteractions = () => { handleNodesDuplicate, handleNodesDelete, handleNodeResize, + handleNodeDisconnect, handleHistoryBack, handleHistoryForward, } diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index a872aee398..e1da503f38 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -1,5 +1,6 @@ import { useCallback } from 'react' import { + getIncomers, useReactFlow, useStoreApi, } from 'reactflow' @@ -8,6 +9,7 @@ import { v4 as uuidV4 } from 'uuid' import { usePathname } from 'next/navigation' import { useWorkflowStore } from '../store' import { useNodesSyncDraft } from '../hooks' +import type { Node } from '../types' import { NodeRunningStatus, WorkflowRunningStatus, @@ -140,9 +142,6 @@ export const useWorkflowRun = () => { resultText: '', }) - let isInIteration = false - let iterationLength = 0 - let ttsUrl = '' let ttsIsPublic = false if (params.token) { @@ -186,7 +185,7 @@ export const useWorkflowRun = () => { draft.forEach((edge) => { edge.data = { ...edge.data, - _run: false, + _runned: false, } }) }) @@ -249,19 +248,20 @@ export const useWorkflowRun = () => { setEdges, transform, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const node = nodes.find(node => node.id === data.node_id) + if (node?.parentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration.push({ + const iterations = tracing.find(trace => trace.node_id === node?.parentId) + const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] + currIteration?.push({ ...data, status: NodeRunningStatus.Running, } as any) })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -288,11 +288,12 @@ export const useWorkflowRun = () => { draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running }) setNodes(newNodes) + const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id) const newEdges = produce(edges, (draft) => { - const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId) - - if (edge) - edge.data = { ...edge.data, _run: true } as any + draft.forEach((edge) => { + if (edge.target === data.node_id && incomeNodesId.includes(edge.source)) + edge.data = { ...edge.data, _runned: true } as any + }) }) setEdges(newEdges) } @@ -309,25 +310,46 @@ export const useWorkflowRun = () => { getNodes, setNodes, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId + if (nodeParentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - const nodeInfo = currIteration[currIteration.length - 1] + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - currIteration[currIteration.length - 1] = { - ...nodeInfo, - ...data, - status: NodeRunningStatus.Succeeded, - } as any + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] + + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) + if (data.status === NodeRunningStatus.Succeeded) { + if (nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + } })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) - + const currentIndex = draft.tracing!.findIndex((trace) => { + if (!trace.execution_metadata?.parallel_id) + return trace.node_id === data.node_id + return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id + }) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras @@ -337,16 +359,14 @@ export const useWorkflowRun = () => { } as any } })) - const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - currentNode.data._runningStatus = data.status as any }) setNodes(newNodes) - prevNodeId = data.node_id } + if (onNodeFinished) onNodeFinished(params) }, @@ -371,8 +391,6 @@ export const useWorkflowRun = () => { details: [], } as any) })) - isInIteration = true - iterationLength = data.metadata.iterator_length const { setViewport, @@ -398,7 +416,7 @@ export const useWorkflowRun = () => { const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId) if (edge) - edge.data = { ...edge.data, _run: true } as any + edge.data = { ...edge.data, _runned: true } as any }) setEdges(newEdges) @@ -418,13 +436,13 @@ export const useWorkflowRun = () => { } = store.getState() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const iteration = draft.tracing![draft.tracing!.length - 1] - if (iteration.details!.length >= iterationLength) - return - - iteration.details!.push([]) + const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id) + if (iteration) { + if (iteration.details!.length >= iteration.metadata.iterator_length!) + return + } + iteration?.details!.push([]) })) - const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! @@ -450,13 +468,14 @@ export const useWorkflowRun = () => { const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], - ...data, - status: NodeRunningStatus.Succeeded, - } as any + const currIterationNode = tracing.find(trace => trace.node_id === data.node_id) + if (currIterationNode) { + Object.assign(currIterationNode, { + ...data, + status: NodeRunningStatus.Succeeded, + }) + } })) - isInIteration = false const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! @@ -470,6 +489,12 @@ export const useWorkflowRun = () => { if (onIterationFinish) onIterationFinish(params) }, + onParallelBranchStarted: (params) => { + // console.log(params, 'parallel start') + }, + onParallelBranchFinished: (params) => { + // console.log(params, 'finished') + }, onTextChunk: (params) => { const { data: { text } } = params const { diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow/hooks/use-workflow-template.ts index 3af3f733f1..e36f0b61f9 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow/hooks/use-workflow-template.ts @@ -10,13 +10,13 @@ export const useWorkflowTemplate = () => { const isChatMode = useIsChatMode() const nodesInitialData = useNodesInitialData() - const startNode = generateNewNode({ + const { newNode: startNode } = generateNewNode({ data: nodesInitialData.start, position: START_INITIAL_POSITION, }) if (isChatMode) { - const llmNode = generateNewNode({ + const { newNode: llmNode } = generateNewNode({ id: 'llm', data: { ...nodesInitialData.llm, @@ -31,7 +31,7 @@ export const useWorkflowTemplate = () => { }, } as any) - const answerNode = generateNewNode({ + const { newNode: answerNode } = generateNewNode({ id: 'answer', data: { ...nodesInitialData.answer, diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index cfff4220fa..460e36ae60 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -6,6 +6,7 @@ import { } from 'react' import dayjs from 'dayjs' import { uniqBy } from 'lodash-es' +import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { getIncomers, @@ -29,6 +30,11 @@ import { useWorkflowStore, } from '../store' import { + getParallelInfo, +} from '../utils' +import { + PARALLEL_DEPTH_LIMIT, + PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' import { CUSTOM_NOTE_NODE } from '../note-node/constants' @@ -50,6 +56,7 @@ import { } from '@/service/tools' import I18n from '@/context/i18n' import { CollectionType } from '@/app/components/tools/types' +import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' export const useIsChatMode = () => { const appDetail = useAppStore(s => s.appDetail) @@ -58,6 +65,7 @@ export const useIsChatMode = () => { } export const useWorkflow = () => { + const { t } = useTranslation() const { locale } = useContext(I18n) const store = useStoreApi() const workflowStore = useWorkflowStore() @@ -77,7 +85,7 @@ export const useWorkflow = () => { const currentNode = nodes.find(node => node.id === nodeId) if (currentNode?.parentId) - startNode = nodes.find(node => node.parentId === currentNode.parentId && node.data.isIterationStart) + startNode = nodes.find(node => node.parentId === currentNode.parentId && node.type === CUSTOM_ITERATION_START_NODE) if (!startNode) return [] @@ -275,6 +283,45 @@ export const useWorkflow = () => { return isUsed }, [isVarUsedInNodes]) + const checkParallelLimit = useCallback((nodeId: string) => { + const { + getNodes, + edges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId)! + const sourceNodeOutgoers = getOutgoers(currentNode, nodes, edges) + if (sourceNodeOutgoers.length > PARALLEL_LIMIT - 1) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) + return false + } + + return true + }, [store, workflowStore, t]) + + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { + const { + parallelList, + hasAbnormalEdges, + } = getParallelInfo(nodes, edges, parentNodeId) + + if (hasAbnormalEdges) + return false + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > PARALLEL_DEPTH_LIMIT) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT })) + return false + } + } + + return true + }, [t, workflowStore]) + const isValidConnection = useCallback(({ source, target }: Connection) => { const { edges, @@ -284,12 +331,15 @@ export const useWorkflow = () => { const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! - if (targetNode.data.isIterationStart) + if (!checkParallelLimit(source!)) return false if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) return false + if (sourceNode.parentId !== targetNode.parentId) + return false + if (sourceNode && targetNode) { const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start] @@ -316,7 +366,7 @@ export const useWorkflow = () => { } return !hasCycle(targetNode) - }, [store, nodesExtraData]) + }, [store, nodesExtraData, checkParallelLimit]) const formatTimeFromNow = useCallback((time: number) => { return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() @@ -339,6 +389,8 @@ export const useWorkflow = () => { isVarUsedInNodes, removeUsedVarInNodes, isNodeVarsUsedInNodes, + checkParallelLimit, + checkNestedParallelLimit, isValidConnection, formatTimeFromNow, getNode, diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index d96faa8677..cdccd60a3b 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -55,6 +55,8 @@ import Header from './header' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' +import CustomIterationStartNode from './nodes/iteration-start' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import Operator from './operator' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' @@ -67,6 +69,7 @@ import NodeContextmenu from './node-contextmenu' import SyncingDataModal from './syncing-data-modal' import UpdateDSLModal from './update-dsl-modal' import DSLExportConfirmModal from './dsl-export-confirm-modal' +import LimitTips from './limit-tips' import { useStore, useWorkflowStore, @@ -92,6 +95,7 @@ import Confirm from '@/app/components/base/confirm' const nodeTypes = { [CUSTOM_NODE]: CustomNode, [CUSTOM_NOTE_NODE]: CustomNoteNode, + [CUSTOM_ITERATION_START_NODE]: CustomIterationStartNode, } const edgeTypes = { [CUSTOM_NODE]: CustomEdge, @@ -317,6 +321,7 @@ const Workflow: FC = memo(({ /> ) } + { + const showTips = useStore(s => s.showTips) + const setShowTips = useStore(s => s.setShowTips) + + if (!showTips) + return null + + return ( +
+
+
+ +
+
+ {showTips} +
+ setShowTips('')} + > + + +
+ ) +} + +export default LimitTips diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 0ab0c8e39e..6e3988eecb 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -21,13 +21,13 @@ type AddProps = { nodeId: string nodeData: CommonNodeType sourceHandle: string - branchName?: string + isParallel?: boolean } const Add = ({ nodeId, nodeData, sourceHandle, - branchName, + isParallel, }: AddProps) => { const { t } = useTranslation() const { handleNodeAdd } = useNodesInteractions() @@ -57,23 +57,19 @@ const Add = ({ ${nodesReadOnly && '!cursor-not-allowed'} `} > - { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - }
- {t('workflow.panel.selectNextStep')} +
+ { + isParallel + ? t('workflow.common.addParallelNode') + : t('workflow.panel.selectNextStep') + } +
) - }, [branchName, t, nodesReadOnly]) + }, [t, nodesReadOnly, isParallel]) return ( { + return ( +
+ { + branchName && ( +
+ {branchName} +
+ ) + } + { + nextNodes.map(nextNode => ( + + )) + } + +
+ ) +} + +export default Container diff --git a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx index 261eb3fac7..d980eb284e 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx @@ -1,4 +1,5 @@ -import { memo } from 'react' +import { memo, useMemo } from 'react' +import { useTranslation } from 'react-i18next' import { getConnectedEdges, getOutgoers, @@ -8,13 +9,11 @@ import { import { useToolIcon } from '../../../../hooks' import BlockIcon from '../../../../block-icon' import type { - Branch, Node, } from '../../../../types' import { BlockEnum } from '../../../../types' -import Add from './add' -import Item from './item' import Line from './line' +import Container from './container' type NextStepProps = { selectedNode: Node @@ -22,15 +21,33 @@ type NextStepProps = { const NextStep = ({ selectedNode, }: NextStepProps) => { + const { t } = useTranslation() const data = selectedNode.data const toolIcon = useToolIcon(data) const store = useStoreApi() - const branches = data._targetBranches || [] + const branches = useMemo(() => { + return data._targetBranches || [] + }, [data]) const nodeWithBranches = data.type === BlockEnum.IfElse || data.type === BlockEnum.QuestionClassifier const edges = useEdges() const outgoers = getOutgoers(selectedNode as Node, store.getState().getNodes(), edges) const connectedEdges = getConnectedEdges([selectedNode] as Node[], edges).filter(edge => edge.source === selectedNode!.id) + const branchesOutgoers = useMemo(() => { + if (!branches?.length) + return [] + + return branches.map((branch) => { + const connected = connectedEdges.filter(edge => edge.sourceHandle === branch.id) + const nextNodes = connected.map(edge => outgoers.find(outgoer => outgoer.id === edge.target)!) + + return { + branch, + nextNodes, + } + }) + }, [branches, connectedEdges, outgoers]) + return (
@@ -39,59 +56,32 @@ const NextStep = ({ toolIcon={toolIcon} />
- -
+ item.nextNodes.length + 1) : [1]} + /> +
{ - !nodeWithBranches && !!outgoers.length && ( - - ) - } - { - !nodeWithBranches && !outgoers.length && ( - ) } { - !!branches?.length && nodeWithBranches && ( - branches.map((branch: Branch) => { - const connected = connectedEdges.find(edge => edge.sourceHandle === branch.id) - const target = outgoers.find(outgoer => outgoer.id === connected?.target) - + nodeWithBranches && ( + branchesOutgoers.map((item, index) => { return ( -
- { - connected && ( - - ) - } - { - !connected && ( - - ) - } -
+ ) }) ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx index b806de5684..db3748abd9 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx @@ -1,94 +1,82 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' -import { intersection } from 'lodash-es' +import Operator from './operator' import type { CommonNodeType, - OnSelectBlock, } from '@/app/components/workflow/types' import BlockIcon from '@/app/components/workflow/block-icon' -import BlockSelector from '@/app/components/workflow/block-selector' import { - useAvailableBlocks, useNodesInteractions, useNodesReadOnly, useToolIcon, } from '@/app/components/workflow/hooks' import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' type ItemProps = { nodeId: string sourceHandle: string - branchName?: string data: CommonNodeType } const Item = ({ nodeId, sourceHandle, - branchName, data, }: ItemProps) => { const { t } = useTranslation() - const { handleNodeChange } = useNodesInteractions() + const [open, setOpen] = useState(false) const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeSelect } = useNodesInteractions() const toolIcon = useToolIcon(data) - const { - availablePrevBlocks, - availableNextBlocks, - } = useAvailableBlocks(data.type, data.isInIteration) - const handleSelect = useCallback((type, toolDefaultValue) => { - handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) - }, [nodeId, sourceHandle, handleNodeChange]) - const renderTrigger = useCallback((open: boolean) => { - return ( - - ) - }, [t]) + const handleOpenChange = useCallback((v: boolean) => { + setOpen(v) + }, []) return (
- { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - } -
{data.title}
+
+ {data.title} +
{ !nodesReadOnly && ( - item !== data.type)} - /> + <> + +
+ +
+ ) }
diff --git a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx index b06a02c158..3a4430cb5d 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx @@ -1,56 +1,70 @@ import { memo } from 'react' type LineProps = { - linesNumber: number + list: number[] } const Line = ({ - linesNumber, + list, }: LineProps) => { - const svgHeight = linesNumber * 36 + (linesNumber - 1) * 12 + const listHeight = list.map((item) => { + return item * 36 + (item - 1) * 2 + 12 + 6 + }) + const processedList = listHeight.map((item, index) => { + if (index === 0) + return item + + return listHeight.slice(0, index).reduce((acc, cur) => acc + cur, 0) + item + }) + const processedListLength = processedList.length + const svgHeight = processedList[processedListLength - 1] + (processedListLength - 1) * 8 return ( { - Array(linesNumber).fill(0).map((_, index) => ( - - { - index === 0 && ( - <> + processedList.map((item, index) => { + const prevItem = index > 0 ? processedList[index - 1] : 0 + const space = prevItem + index * 8 + 16 + return ( + + { + index === 0 && ( + <> + + + + ) + } + { + index > 0 && ( - - - ) - } - { - index > 0 && ( - - ) - } - - - )) + ) + } + + + ) + }) } ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx new file mode 100644 index 0000000000..ad6c7abd0c --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx @@ -0,0 +1,129 @@ +import { + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiMoreFill } from '@remixicon/react' +import { intersection } from 'lodash-es' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import BlockSelector from '@/app/components/workflow/block-selector' +import { + useAvailableBlocks, + useNodesInteractions, +} from '@/app/components/workflow/hooks' +import type { + CommonNodeType, + OnSelectBlock, +} from '@/app/components/workflow/types' + +type ChangeItemProps = { + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const ChangeItem = ({ + data, + nodeId, + sourceHandle, +}: ChangeItemProps) => { + const { t } = useTranslation() + + const { handleNodeChange } = useNodesInteractions() + const { + availablePrevBlocks, + availableNextBlocks, + } = useAvailableBlocks(data.type, data.isInIteration) + + const handleSelect = useCallback((type, toolDefaultValue) => { + handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) + }, [nodeId, sourceHandle, handleNodeChange]) + + const renderTrigger = useCallback(() => { + return ( +
+ {t('workflow.panel.change')} +
+ ) + }, [t]) + + return ( + item !== data.type)} + /> + ) +} + +type OperatorProps = { + open: boolean + onOpenChange: (v: boolean) => void + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const Operator = ({ + open, + onOpenChange, + data, + nodeId, + sourceHandle, +}: OperatorProps) => { + const { t } = useTranslation() + const { + handleNodeDelete, + handleNodeDisconnect, + } = useNodesInteractions() + + return ( + + onOpenChange(!open)}> + + + +
+
+ +
handleNodeDisconnect(nodeId)} + > + {t('workflow.common.disconnect')} +
+
+
+
handleNodeDelete(nodeId)} + > + {t('common.operation.delete')} +
+
+
+
+
+ ) +} + +export default Operator diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index 56870f79d6..bcd03d6a5e 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -9,16 +9,22 @@ import { Handle, Position, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { BlockEnum } from '../../../types' import type { Node } from '../../../types' import BlockSelector from '../../../block-selector' import type { ToolDefaultValue } from '../../../block-selector/types' import { useAvailableBlocks, + useIsChatMode, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '../../../hooks' -import { useStore } from '../../../store' +import { + useStore, +} from '../../../store' +import Tooltip from '@/app/components/base/tooltip' type NodeHandleProps = { handleId: string @@ -38,9 +44,7 @@ export const NodeTargetHandle = memo(({ const { getNodesReadOnly } = useNodesReadOnly() const connected = data._connectedTargetHandleIds?.includes(handleId) const { availablePrevBlocks } = useAvailableBlocks(data.type, data.isInIteration) - const isConnectable = !!availablePrevBlocks.length && ( - !data.isIterationStart - ) + const isConnectable = !!availablePrevBlocks.length const handleOpenChange = useCallback((v: boolean) => { setOpen(v) @@ -112,12 +116,15 @@ export const NodeSourceHandle = memo(({ handleClassName, nodeSelectorClassName, }: NodeHandleProps) => { + const { t } = useTranslation() const notInitialWorkflow = useStore(s => s.notInitialWorkflow) const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { getNodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration) const isConnectable = !!availableNextBlocks.length + const isChatMode = useIsChatMode() + const { checkParallelLimit } = useWorkflow() const connected = data._connectedSourceHandleIds?.includes(handleId) const handleOpenChange = useCallback((v: boolean) => { @@ -125,9 +132,9 @@ export const NodeSourceHandle = memo(({ }, []) const handleHandleClick = useCallback((e: MouseEvent) => { e.stopPropagation() - if (!connected) + if (checkParallelLimit(id)) setOpen(v => !v) - }, [connected]) + }, [checkParallelLimit, id]) const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { handleNodeAdd( { @@ -142,12 +149,25 @@ export const NodeSourceHandle = memo(({ }, [handleNodeAdd, id, handleId]) useEffect(() => { - if (notInitialWorkflow && data.type === BlockEnum.Start) + if (notInitialWorkflow && data.type === BlockEnum.Start && !isChatMode) setOpen(true) - }, [notInitialWorkflow, data.type]) + }, [notInitialWorkflow, data.type, isChatMode]) return ( - <> + +
+ {t('workflow.common.parallelTip.click.title')} + {t('workflow.common.parallelTip.click.desc')} +
+
+ {t('workflow.common.parallelTip.drag.title')} + {t('workflow.common.parallelTip.drag.desc')} +
+
+ )} + > { - !connected && isConnectable && !getNodesReadOnly() && ( + isConnectable && !getNodesReadOnly() && ( - + ) }) NodeSourceHandle.displayName = 'NodeSourceHandle' diff --git a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx index 4c83bea8d6..a8e7a9aa11 100644 --- a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx @@ -28,8 +28,8 @@ const NodeResizer = ({ nodeId, nodeData, icon = , - minWidth = 272, - minHeight = 176, + minWidth = 258, + minHeight = 152, maxWidth, }: NodeResizerProps) => { const { handleNodeResize } = useNodesInteractions() diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 0b45c80888..bd5921c735 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -14,6 +14,7 @@ import { RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' +import { useTranslation } from 'react-i18next' import type { NodeProps } from '../../types' import { BlockEnum, @@ -43,6 +44,7 @@ const BaseNode: FC = ({ data, children, }) => { + const { t } = useTranslation() const nodeRef = useRef(null) const { nodesReadOnly } = useNodesReadOnly() const { handleNodeIterationChildSizeChange } = useNodeIterationInteractions() @@ -80,6 +82,7 @@ const BaseNode: FC = ({ className={cn( 'flex border-[2px] rounded-2xl', showSelectedBorder ? 'border-components-option-card-option-selected-border' : 'border-transparent', + !showSelectedBorder && data._inParallelHovering && 'border-workflow-block-border-highlight', )} ref={nodeRef} style={{ @@ -100,6 +103,13 @@ const BaseNode: FC = ({ data._isBundled && '!shadow-lg', )} > + { + data._inParallelHovering && ( +
+ {t('workflow.common.parallelRun')} +
+ ) + } { data._showAddVariablePopup && ( { ComparisonOperator.isNot, ComparisonOperator.empty, ComparisonOperator.notEmpty, - ComparisonOperator.regexMatch, ] case VarType.number: return [ diff --git a/web/app/components/workflow/nodes/iteration-start/constants.ts b/web/app/components/workflow/nodes/iteration-start/constants.ts new file mode 100644 index 0000000000..94e3ccbd90 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/constants.ts @@ -0,0 +1 @@ +export const CUSTOM_ITERATION_START_NODE = 'custom-iteration-start' diff --git a/web/app/components/workflow/nodes/iteration-start/default.ts b/web/app/components/workflow/nodes/iteration-start/default.ts new file mode 100644 index 0000000000..d98efa7ba2 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/default.ts @@ -0,0 +1,21 @@ +import type { NodeDefault } from '../../types' +import type { IterationStartNodeType } from './types' +import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' + +const nodeDefault: NodeDefault = { + defaultValue: {}, + getAvailablePrevNodes() { + return [] + }, + getAvailableNextNodes(isChatMode: boolean) { + const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + return nodes + }, + checkValid() { + return { + isValid: true, + } + }, +} + +export default nodeDefault diff --git a/web/app/components/workflow/nodes/iteration-start/index.tsx b/web/app/components/workflow/nodes/iteration-start/index.tsx new file mode 100644 index 0000000000..9d7ac1f905 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/index.tsx @@ -0,0 +1,42 @@ +import { memo } from 'react' +import { useTranslation } from 'react-i18next' +import type { NodeProps } from 'reactflow' +import { RiHome5Fill } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' +import { NodeSourceHandle } from '@/app/components/workflow/nodes/_base/components/node-handle' + +const IterationStartNode = ({ id, data }: NodeProps) => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+ +
+ ) +} + +export const IterationStartNodeDumb = () => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+
+ ) +} + +export default memo(IterationStartNode) diff --git a/web/app/components/workflow/nodes/iteration-start/types.ts b/web/app/components/workflow/nodes/iteration-start/types.ts new file mode 100644 index 0000000000..319cce0bc2 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/types.ts @@ -0,0 +1,3 @@ +import type { CommonNodeType } from '@/app/components/workflow/types' + +export type IterationStartNodeType = CommonNodeType diff --git a/web/app/components/workflow/nodes/iteration/add-block.tsx b/web/app/components/workflow/nodes/iteration/add-block.tsx index fd8480b7df..07e2b5daf0 100644 --- a/web/app/components/workflow/nodes/iteration/add-block.tsx +++ b/web/app/components/workflow/nodes/iteration/add-block.tsx @@ -2,87 +2,49 @@ import { memo, useCallback, } from 'react' -import produce from 'immer' import { RiAddLine, } from '@remixicon/react' -import { useStoreApi } from 'reactflow' import { useTranslation } from 'react-i18next' import { - generateNewNode, -} from '../../utils' -import { - WorkflowHistoryEvent, useAvailableBlocks, + useNodesInteractions, useNodesReadOnly, - useWorkflowHistory, } from '../../hooks' -import { NODES_INITIAL_DATA } from '../../constants' -import InsertBlock from './insert-block' import type { IterationNodeType } from './types' import cn from '@/utils/classnames' import BlockSelector from '@/app/components/workflow/block-selector' -import { IterationStart } from '@/app/components/base/icons/src/vender/workflow' import type { OnSelectBlock, } from '@/app/components/workflow/types' import { BlockEnum, } from '@/app/components/workflow/types' -import Tooltip from '@/app/components/base/tooltip' type AddBlockProps = { iterationNodeId: string iterationNodeData: IterationNodeType } const AddBlock = ({ - iterationNodeId, iterationNodeData, }: AddBlockProps) => { const { t } = useTranslation() - const store = useStoreApi() const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeAdd } = useNodesInteractions() const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, true) - const { availablePrevBlocks } = useAvailableBlocks(iterationNodeData.startNodeType, true) - const { saveStateToHistory } = useWorkflowHistory() const handleSelect = useCallback((type, toolDefaultValue) => { - const { - getNodes, - setNodes, - } = store.getState() - const nodes = getNodes() - const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ - data: { - ...NODES_INITIAL_DATA[type], - title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), - ...(toolDefaultValue || {}), - isIterationStart: true, - isInIteration: true, - iteration_id: iterationNodeId, + handleNodeAdd( + { + nodeType: type, + toolDefaultValue, }, - position: { - x: 117, - y: 85, + { + prevNodeId: iterationNodeData.start_node_id, + prevNodeSourceHandle: 'source', }, - zIndex: 1001, - parentId: iterationNodeId, - extent: 'parent', - }) - const newNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - if (node.id === iterationNodeId) { - node.data._children = [newNode.id] - node.data.start_node_id = newNode.id - node.data.startNodeType = newNode.data.type - } - }) - draft.push(newNode) - }) - setNodes(newNodes) - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [store, t, iterationNodeId, saveStateToHistory]) + ) + }, [handleNodeAdd, iterationNodeData.start_node_id]) const renderTriggerElement = useCallback((open: boolean) => { return ( @@ -98,35 +60,18 @@ const AddBlock = ({ }, [nodesReadOnly, t]) return ( -
- -
- -
-
+
- { - iterationNodeData.startNodeType && ( - - ) - }
- { - !iterationNodeData.startNodeType && ( - - ) - } +
) } diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 43f8a751ac..3afa52d06e 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -9,6 +9,7 @@ const nodeDefault: NodeDefault = { start_node_id: '', iterator_selector: [], output_selector: [], + _children: [], }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/iteration/insert-block.tsx b/web/app/components/workflow/nodes/iteration/insert-block.tsx deleted file mode 100644 index d041fe1c74..0000000000 --- a/web/app/components/workflow/nodes/iteration/insert-block.tsx +++ /dev/null @@ -1,61 +0,0 @@ -import { - memo, - useCallback, - useState, -} from 'react' -import { useNodesInteractions } from '../../hooks' -import type { - BlockEnum, - OnSelectBlock, -} from '../../types' -import BlockSelector from '../../block-selector' -import cn from '@/utils/classnames' - -type InsertBlockProps = { - startNodeId: string - availableBlocksTypes: BlockEnum[] -} -const InsertBlock = ({ - startNodeId, - availableBlocksTypes, -}: InsertBlockProps) => { - const [open, setOpen] = useState(false) - const { handleNodeAdd } = useNodesInteractions() - - const handleOpenChange = useCallback((v: boolean) => { - setOpen(v) - }, []) - const handleInsert = useCallback((nodeType, toolDefaultValue) => { - handleNodeAdd( - { - nodeType, - toolDefaultValue, - }, - { - nextNodeId: startNodeId, - nextNodeTargetHandle: 'target', - }, - ) - }, [startNodeId, handleNodeAdd]) - - return ( -
- 'hover:scale-125 transition-all'} - /> -
- ) -} - -export default memo(InsertBlock) diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index f4520402f3..48a005a261 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,6 +8,7 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' @@ -29,7 +30,7 @@ const Node: FC> = ({ return (
> = ({ size={2 / zoom} color='#E4E5E7' /> - + { + data._isCandidate && ( + + ) + } + { + data._children!.length === 1 && ( + + ) + }
) } diff --git a/web/app/components/workflow/nodes/iteration/use-interactions.ts b/web/app/components/workflow/nodes/iteration/use-interactions.ts index 219c8e731f..f8e3640cc4 100644 --- a/web/app/components/workflow/nodes/iteration/use-interactions.ts +++ b/web/app/components/workflow/nodes/iteration/use-interactions.ts @@ -11,6 +11,7 @@ import { ITERATION_PADDING, NODES_INITIAL_DATA, } from '../../constants' +import { CUSTOM_ITERATION_START_NODE } from '../iteration-start/constants' export const useNodeIterationInteractions = () => { const { t } = useTranslation() @@ -107,12 +108,12 @@ export const useNodeIterationInteractions = () => { const handleNodeIterationChildrenCopy = useCallback((nodeId: string, newNodeId: string) => { const { getNodes } = store.getState() const nodes = getNodes() - const childrenNodes = nodes.filter(n => n.parentId === nodeId) + const childrenNodes = nodes.filter(n => n.parentId === nodeId && n.type !== CUSTOM_ITERATION_START_NODE) return childrenNodes.map((child, index) => { const childNodeType = child.data.type as BlockEnum const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[childNodeType], ...child.data, @@ -121,6 +122,7 @@ export const useNodeIterationInteractions = () => { _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${childNodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${childNodeType}`), + iteration_id: newNodeId, }, position: child.position, positionAbsolute: child.positionAbsolute, diff --git a/web/app/components/workflow/operator/add-block.tsx b/web/app/components/workflow/operator/add-block.tsx index 48222cc528..388fbc053f 100644 --- a/web/app/components/workflow/operator/add-block.tsx +++ b/web/app/components/workflow/operator/add-block.tsx @@ -55,7 +55,7 @@ const AddBlock = ({ } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[type], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), diff --git a/web/app/components/workflow/operator/hooks.ts b/web/app/components/workflow/operator/hooks.ts index 5b14211497..edec10bda7 100644 --- a/web/app/components/workflow/operator/hooks.ts +++ b/web/app/components/workflow/operator/hooks.ts @@ -11,7 +11,7 @@ export const useOperator = () => { const { userProfile } = useAppContext() const handleAddNote = useCallback(() => { - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ type: CUSTOM_NOTE_NODE, data: { title: '', diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 54d3915a13..51a018bcb1 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -180,8 +180,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) const bodyParams = { @@ -317,11 +315,11 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, - onIterationNext: () => { + onIterationNext: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] + const iterations = tracing.find(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! iterations.details!.push([]) handleUpdateChatList(produce(chatListRef.current, (draft) => { @@ -331,9 +329,10 @@ export const useChat = ( }, onIterationFinish: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - tracing[tracing.length - 1] = { - ...iterations, + const iterationsIndex = tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + tracing[iterationsIndex] = { + ...tracing[iterationsIndex], ...data, status: NodeRunningStatus.Succeeded, } as any @@ -341,67 +340,45 @@ export const useChat = ( const currentIndex = draft.length - 1 draft[currentIndex] = responseItem })) - - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) { - const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.length - 1 - draft[currentIndex] = responseItem - })) - } - else { - responseItem.workflowProcess!.tracing!.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.findIndex(item => item.id === responseItem.id) - draft[currentIndex] = { - ...draft[currentIndex], - ...responseItem, - } - })) - } + if (data.iteration_id) + return + + responseItem.workflowProcess!.tracing!.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) }, onNodeFinished: ({ data }) => { - if (isInIteration) { - const tracing = responseItem.workflowProcess!.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration[currIteration.length - 1] = { - ...data, - status: NodeRunningStatus.Succeeded, - } as any - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.length - 1 - draft[currentIndex] = responseItem - })) - } - else { - const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id) - responseItem.workflowProcess!.tracing[currentIndex] = { - ...(responseItem.workflowProcess!.tracing[currentIndex].extras - ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } - : {}), - ...data, - } as any - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.findIndex(item => item.id === responseItem.id) - draft[currentIndex] = { - ...draft[currentIndex], - ...responseItem, - } - })) - } + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + responseItem.workflowProcess!.tracing[currentIndex] = { + ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras + ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } + : {}), + ...data, + } as any + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) }, }, ) diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index b9b77fc4a3..331ef1c2f5 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -63,26 +63,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const formatNodeList = useCallback((list: NodeTracing[]) => { const allItems = list.reverse() const result: NodeTracing[] = [] - let iterationIndex = 0 allItems.forEach((item) => { const { node_type, execution_metadata } = item if (node_type !== BlockEnum.Iteration) { const isInIteration = !!execution_metadata?.iteration_id if (isInIteration) { - const iterationDetails = result[result.length - 1].details! - const currentIterationIndex = execution_metadata?.iteration_index - const isIterationFirstNode = iterationIndex !== currentIterationIndex || iterationDetails.length === 0 + const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) + const iterationDetails = iterationNode?.details + const currentIterationIndex = execution_metadata?.iteration_index ?? 0 - if (isIterationFirstNode) { - iterationDetails!.push([item]) - iterationIndex = currentIterationIndex! + if (Array.isArray(iterationDetails)) { + if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) + iterationDetails[currentIterationIndex] = [item] + else + iterationDetails[currentIterationIndex].push(item) } - - else { - iterationDetails[iterationDetails.length - 1].push(item) - } - return } // not in iteration @@ -90,7 +86,6 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe return } - result.push({ ...item, details: [], diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index c833ea0342..4fc30f03df 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -1,10 +1,14 @@ 'use client' import type { FC } from 'react' -import React, { useCallback } from 'react' +import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { RiCloseLine } from '@remixicon/react' +import { + RiArrowRightSLine, + RiCloseLine, +} from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' -import NodePanel from './node' +import TracingPanel from './tracing-panel' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import type { NodeTracing } from '@/types/workflow' const i18nPrefix = 'workflow.singleRun' @@ -23,43 +27,67 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() + const [expandedIterations, setExpandedIterations] = useState>([]) + + const toggleIteration = useCallback((index: number) => { + setExpandedIterations(prev => ({ + ...prev, + [index]: !prev[index], + })) + }, []) const main = ( <> -
+
-
+
{t(`${i18nPrefix}.testRunIteration`)}
- +
-
+
-
{t(`${i18nPrefix}.back`)}
+
{t(`${i18nPrefix}.back`)}
{/* List */} -
+
{list.map((iteration, index) => ( -
-
-
{t(`${i18nPrefix}.iteration`)} {index + 1}
-
+
+
toggleIteration(index)} + > +
+
+ +
+ + {t(`${i18nPrefix}.iteration`)} {index + 1} + + +
-
- {iteration.map(node => ( - - ))} + {expandedIterations[index] &&
} +
+
))} diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index 66f996f13b..2e45290ddf 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -4,15 +4,17 @@ import type { FC } from 'react' import { useCallback, useEffect, useState } from 'react' import { RiArrowRightSLine, - RiCheckboxCircleLine, + RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' import Split from '../nodes/_base/components/split' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import type { NodeTracing } from '@/types/workflow' @@ -61,31 +63,38 @@ const NodePanel: FC = ({ return `${parseFloat((tokens / 1000000).toFixed(3))}M` } + const getCount = (iteration_curr_length: number | undefined, iteration_length: number) => { + if ((iteration_curr_length && iteration_curr_length < iteration_length) || !iteration_length) + return iteration_curr_length + + return iteration_length + } + useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration - const handleOnShowIterationDetail = (e: React.MouseEvent) => { + const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || []) } return ( -
-
+
+
setCollapseState(!collapseState)} > {!hideProcessDetail && ( @@ -93,23 +102,23 @@ const NodePanel: FC = ({
{nodeInfo.title}
{nodeInfo.status !== 'running' && !hideInfo && ( -
{`${getTime(nodeInfo.elapsed_time || 0)} · ${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens`}
+
{nodeInfo.execution_metadata?.total_tokens ? `${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens · ` : ''}{`${getTime(nodeInfo.elapsed_time || 0)}`}
)} {nodeInfo.status === 'succeeded' && ( - + )} {nodeInfo.status === 'failed' && ( - + )} {nodeInfo.status === 'stopped' && ( )} {nodeInfo.status === 'running' && ( -
+
Running
@@ -120,25 +129,27 @@ const NodePanel: FC = ({ {/* The nav to the iteration detail */} {isIterationNode && !notShowIterationNav && (
-
-
{t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length || nodeInfo.details?.length })}
+
+
)} -
+
{nodeInfo.status === 'stopped' && (
{t('workflow.tracing.stopBy', { user: nodeInfo.created_by ? nodeInfo.created_by.name : 'N/A' })}
)} diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index c80bcb41db..7684dc84b6 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -1,24 +1,266 @@ 'use client' import type { FC } from 'react' +import +React, +{ + useCallback, + useState, +} from 'react' +import cn from 'classnames' +import { + RiArrowDownSLine, + RiMenu4Line, +} from '@remixicon/react' import NodePanel from './node' +import { + BlockEnum, +} from '@/app/components/workflow/types' import type { NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] - onShowIterationDetail: (detail: NodeTracing[][]) => void + onShowIterationDetail?: (detail: NodeTracing[][]) => void + className?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean } -const TracingPanel: FC = ({ list, onShowIterationDetail }) => { +type TracingNodeProps = { + id: string + uniqueId: string + isParallel: boolean + data: NodeTracing | null + children: TracingNodeProps[] + parallelTitle?: string + branchTitle?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean +} + +function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { + const rootNodes: TracingNodeProps[] = [] + const parallelStacks: { [key: string]: TracingNodeProps } = {} + const levelCounts: { [key: string]: number } = {} + const parallelChildCounts: { [key: string]: Set } = {} + let uniqueIdCounter = 0 + const getUniqueId = () => { + uniqueIdCounter++ + return `unique-${uniqueIdCounter}` + } + + const getParallelTitle = (parentId: string | null): string => { + const levelKey = parentId || 'root' + if (!levelCounts[levelKey]) + levelCounts[levelKey] = 0 + + levelCounts[levelKey]++ + + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + return `PARALLEL-${levelNumber}${letter}` + } + + const getBranchTitle = (parentId: string | null, branchNum: number): string => { + const levelKey = parentId || 'root' + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + const branchLetter = String.fromCharCode(64 + branchNum) + return `BRANCH-${levelNumber}${letter}-${branchLetter}` + } + + // Count parallel children (for figuring out if we need to use letters) + for (const node of nodes) { + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + + if (parallel_id) { + const parentKey = parent_parallel_id || 'root' + if (!parallelChildCounts[parentKey]) + parallelChildCounts[parentKey] = new Set() + + parallelChildCounts[parentKey].add(parallel_id) + } + } + + for (const node of nodes) { + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_start_node_id = node.parallel_start_node_id ?? node.execution_metadata?.parallel_start_node_id ?? null + const parent_parallel_start_node_id = node.parent_parallel_start_node_id ?? node.execution_metadata?.parent_parallel_start_node_id ?? null + + if (!parallel_id || node.node_type === BlockEnum.End) { + rootNodes.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + }) + } + else { + if (!parallelStacks[parallel_id]) { + const newParallelGroup: TracingNodeProps = { + id: parallel_id, + uniqueId: getUniqueId(), + isParallel: true, + data: null, + children: [], + parallelTitle: '', + } + parallelStacks[parallel_id] = newParallelGroup + + if (parent_parallel_id && parallelStacks[parent_parallel_id]) { + const sameBranchIndex = parallelStacks[parent_parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parent_parallel_start_node_id || c.data?.parallel_start_node_id === parent_parallel_start_node_id, + ) + parallelStacks[parent_parallel_id].children.splice(sameBranchIndex + 1, 0, newParallelGroup) + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + } + else { + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + rootNodes.push(newParallelGroup) + } + } + const branchTitle = parallel_start_node_id === node.node_id ? getBranchTitle(parent_parallel_id, parallelStacks[parallel_id].children.length + 1) : '' + if (branchTitle) { + parallelStacks[parallel_id].children.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + else { + let sameBranchIndex = parallelStacks[parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parallel_start_node_id || c.data?.parallel_start_node_id === parallel_start_node_id, + ) + if (parallelStacks[parallel_id].children[sameBranchIndex + 1]?.isParallel) + sameBranchIndex++ + + parallelStacks[parallel_id].children.splice(sameBranchIndex + 1, 0, { + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + } + } + + return rootNodes +} + +const TracingPanel: FC = ({ + list, + onShowIterationDetail, + className, + hideNodeInfo = false, + hideNodeProcessDetail = false, +}) => { + const treeNodes = buildLogTree(list) + const [collapsedNodes, setCollapsedNodes] = useState>(new Set()) + const [hoveredParallel, setHoveredParallel] = useState(null) + + const toggleCollapse = (id: string) => { + setCollapsedNodes((prev) => { + const newSet = new Set(prev) + if (newSet.has(id)) + newSet.delete(id) + + else + newSet.add(id) + + return newSet + }) + } + + const handleParallelMouseEnter = useCallback((id: string) => { + setHoveredParallel(id) + }, []) + + const handleParallelMouseLeave = useCallback((e: React.MouseEvent) => { + const relatedTarget = e.relatedTarget as Element | null + if (relatedTarget && 'closest' in relatedTarget) { + const closestParallel = relatedTarget.closest('[data-parallel-id]') + if (closestParallel) + setHoveredParallel(closestParallel.getAttribute('data-parallel-id')) + + else + setHoveredParallel(null) + } + else { + setHoveredParallel(null) + } + }, []) + + const renderNode = (node: TracingNodeProps) => { + if (node.isParallel) { + const isCollapsed = collapsedNodes.has(node.id) + const isHovered = hoveredParallel === node.id + return ( +
handleParallelMouseEnter(node.id)} + onMouseLeave={handleParallelMouseLeave} + > +
+ +
+ {node.parallelTitle} +
+
+
+
+
+ {node.children.map(renderNode)} +
+
+ ) + } + else { + const isHovered = hoveredParallel === node.id + return ( +
+
+ {node.branchTitle} +
+ +
+ ) + } + } + return ( -
- {list.map(node => ( - - ))} +
+ {treeNodes.map(renderNode)}
) } diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index 2e5e774191..853d0c5934 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -162,6 +162,8 @@ type Shape = { setControlPromptEditorRerenderKey: (controlPromptEditorRerenderKey: number) => void showImportDSLModal: boolean setShowImportDSLModal: (showImportDSLModal: boolean) => void + showTips: string + setShowTips: (showTips: string) => void } export const createWorkflowStore = () => { @@ -262,6 +264,8 @@ export const createWorkflowStore = () => { setControlPromptEditorRerenderKey: controlPromptEditorRerenderKey => set(() => ({ controlPromptEditorRerenderKey })), showImportDSLModal: false, setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), + showTips: '', + setShowTips: showTips => set(() => ({ showTips })), })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 12957aa0dd..797c2dbd85 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -26,6 +26,7 @@ export enum BlockEnum { Tool = 'tool', ParameterExtractor = 'parameter-extractor', Iteration = 'iteration', + IterationStart = 'iteration-start', Assigner = 'assigner', // is now named as VariableAssigner } @@ -54,7 +55,7 @@ export type CommonNodeType = { _holdAddVariablePopup?: boolean _iterationLength?: number _iterationIndex?: number - isIterationStart?: boolean + _inParallelHovering?: boolean isInIteration?: boolean iteration_id?: string selected?: boolean diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index e73485f546..91656e3bbc 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -1,12 +1,15 @@ import { Position, getConnectedEdges, + getIncomers, getOutgoers, } from 'reactflow' import dagre from '@dagrejs/dagre' import { v4 as uuid4 } from 'uuid' import { cloneDeep, + groupBy, + isEqual, uniqBy, } from 'lodash-es' import type { @@ -19,14 +22,17 @@ import type { import { BlockEnum } from './types' import { CUSTOM_NODE, + ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION, } from './constants' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import type { QuestionClassifierNodeType } from './nodes/question-classifier/types' import type { IfElseNodeType } from './nodes/if-else/types' import { branchNameCorrect } from './nodes/if-else/utils' import type { ToolNodeType } from './nodes/tool/types' +import type { IterationNodeType } from './nodes/iteration/types' import { CollectionType } from '@/app/components/tools/types' import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' @@ -84,9 +90,130 @@ const getCycleEdges = (nodes: Node[], edges: Edge[]) => { return cycleEdges } +export function getIterationStartNode(iterationId: string): Node { + return generateNewNode({ + id: `${iterationId}start`, + type: CUSTOM_ITERATION_START_NODE, + data: { + title: '', + desc: '', + type: BlockEnum.IterationStart, + isInIteration: true, + }, + position: { + x: 24, + y: 68, + }, + zIndex: ITERATION_CHILDREN_Z_INDEX, + parentId: iterationId, + selectable: false, + draggable: false, + }).newNode +} + +export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit & { id?: string }): { + newNode: Node + newIterationStartNode?: Node +} { + const newNode = { + id: id || `${Date.now()}`, + type: type || CUSTOM_NODE, + data, + position, + targetPosition: Position.Left, + sourcePosition: Position.Right, + zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex, + ...rest, + } as Node + + if (data.type === BlockEnum.Iteration) { + const newIterationStartNode = getIterationStartNode(newNode.id); + (newNode.data as IterationNodeType).start_node_id = newIterationStartNode.id; + (newNode.data as IterationNodeType)._children = [newIterationStartNode.id] + return { + newNode, + newIterationStartNode, + } + } + + return { + newNode, + } +} + +export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => { + const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration) + + if (!hasIterationNode) { + return { + nodes, + edges, + } + } + const nodesMap = nodes.reduce((prev, next) => { + prev[next.id] = next + return prev + }, {} as Record) + const iterationNodesWithStartNode = [] + const iterationNodesWithoutStartNode = [] + + for (let i = 0; i < nodes.length; i++) { + const currentNode = nodes[i] as Node + + if (currentNode.data.type === BlockEnum.Iteration) { + if (currentNode.data.start_node_id) { + if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE) + iterationNodesWithStartNode.push(currentNode) + } + else { + iterationNodesWithoutStartNode.push(currentNode) + } + } + } + const newIterationStartNodesMap = {} as Record + const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => { + const newNode = getIterationStartNode(iterationNode.id) + newNode.id = newNode.id + index + newIterationStartNodesMap[iterationNode.id] = newNode + return newNode + }) + const newEdges = iterationNodesWithStartNode.map((iterationNode) => { + const newNode = newIterationStartNodesMap[iterationNode.id] + const startNode = nodesMap[iterationNode.data.start_node_id] + const source = newNode.id + const sourceHandle = 'source' + const target = startNode.id + const targetHandle = 'target' + return { + id: `${source}-${sourceHandle}-${target}-${targetHandle}`, + type: 'custom', + source, + sourceHandle, + target, + targetHandle, + data: { + sourceType: newNode.data.type, + targetType: startNode.data.type, + isInIteration: true, + iteration_id: startNode.parentId, + _connectedNodeIsSelected: true, + }, + zIndex: ITERATION_CHILDREN_Z_INDEX, + } + }) + nodes.forEach((node) => { + if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id]) + (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id + }) + + return { + nodes: [...nodes, ...newIterationStartNodes], + edges: [...edges, ...newEdges], + } +} + export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { - const nodes = cloneDeep(originNodes) - const edges = cloneDeep(originEdges) + const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges)) const firstNode = nodes[0] if (!firstNode?.position) { @@ -148,8 +275,7 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { } export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => { - const nodes = cloneDeep(originNodes) - const edges = cloneDeep(originEdges) + const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges)) let selectedNode: Node | null = null const nodesMap = nodes.reduce((acc, node) => { acc[node.id] = node @@ -291,19 +417,6 @@ export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSo return nodesConnectedSourceOrTargetHandleIdsMap } -export const generateNewNode = ({ data, position, id, zIndex, type, ...rest }: Omit & { id?: string }) => { - return { - id: id || `${Date.now()}`, - type: type || CUSTOM_NODE, - data, - position, - targetPosition: Position.Left, - sourcePosition: Position.Right, - zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex, - ...rest, - } as Node -} - export const genNewNodeTitleFromOld = (oldTitle: string) => { const regex = /^(.+?)\s*\((\d+)\)\s*$/ const match = oldTitle.match(regex) @@ -479,3 +592,167 @@ export const variableTransformer = (v: ValueSelector | string) => { return `{{#${v.join('.')}#}}` } + +type ParallelInfoItem = { + parallelNodeId: string + depth: number + isBranch?: boolean +} +type NodeParallelInfo = { + parallelNodeId: string + edgeHandleId: string + depth: number +} +type NodeHandle = { + node: Node + handle: string +} +type NodeStreamInfo = { + upstreamNodes: Set + downstreamEdges: Set +} +export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => { + let startNode + + if (parentNodeId) { + const parentNode = nodes.find(node => node.id === parentNodeId) + if (!parentNode) + throw new Error('Parent node not found') + + startNode = nodes.find(node => node.id === (parentNode.data as IterationNodeType).start_node_id) + } + else { + startNode = nodes.find(node => node.data.type === BlockEnum.Start) + } + if (!startNode) + throw new Error('Start node not found') + + const parallelList = [] as ParallelInfoItem[] + const nextNodeHandles = [{ node: startNode, handle: 'source' }] + let hasAbnormalEdges = false + + const traverse = (firstNodeHandle: NodeHandle) => { + const nodeEdgesSet = {} as Record> + const totalEdgesSet = new Set() + const nextHandles = [firstNodeHandle] + const streamInfo = {} as Record + const parallelListItem = { + parallelNodeId: '', + depth: 0, + } as ParallelInfoItem + const nodeParallelInfoMap = {} as Record + nodeParallelInfoMap[firstNodeHandle.node.id] = { + parallelNodeId: '', + edgeHandleId: '', + depth: 0, + } + + while (nextHandles.length) { + const currentNodeHandle = nextHandles.shift()! + const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle + const currentNodeHandleKey = currentNode.id + const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle) + const connectedEdgesLength = connectedEdges.length + const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id)) + const incomers = getIncomers(currentNode, nodes, edges) + + if (!streamInfo[currentNodeHandleKey]) { + streamInfo[currentNodeHandleKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) { + const newSet = new Set() + for (const item of totalEdgesSet) { + if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item)) + newSet.add(item) + } + if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) { + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + nextNodeHandles.push({ node: currentNode, handle: currentHandle }) + break + } + } + + if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth) + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + + outgoers.forEach((outgoer) => { + const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id) + const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle') + const incomers = getIncomers(outgoer, nodes, edges) + + if (outgoers.length > 1 && incomers.length > 1) + hasAbnormalEdges = true + + Object.keys(sourceEdgesGroup).forEach((sourceHandle) => { + nextHandles.push({ node: outgoer, handle: sourceHandle }) + }) + if (!outgoerConnectedEdges.length) + nextHandles.push({ node: outgoer, handle: 'source' }) + + const outgoerKey = outgoer.id + if (!nodeEdgesSet[outgoerKey]) + nodeEdgesSet[outgoerKey] = new Set() + + if (nodeEdgesSet[currentNodeHandleKey]) { + for (const item of nodeEdgesSet[currentNodeHandleKey]) + nodeEdgesSet[outgoerKey].add(item) + } + + if (!streamInfo[outgoerKey]) { + streamInfo[outgoerKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (!nodeParallelInfoMap[outgoer.id]) { + nodeParallelInfoMap[outgoer.id] = { + ...nodeParallelInfoMap[currentNode.id], + } + } + + if (connectedEdgesLength > 1) { + const edge = connectedEdges.find(edge => edge.target === outgoer.id)! + nodeEdgesSet[outgoerKey].add(edge.id) + totalEdgesSet.add(edge.id) + + streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id) + streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey) + + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[item].downstreamEdges.add(edge.id) + + if (!parallelListItem.parallelNodeId) + parallelListItem.parallelNodeId = currentNode.id + + const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1 + const currentDepth = nodeParallelInfoMap[outgoer.id].depth + + nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth) + } + else { + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[outgoerKey].upstreamNodes.add(item) + + nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth + } + }) + } + + parallelList.push(parallelListItem) + } + + while (nextNodeHandles.length) { + const nodeHandle = nextNodeHandles.shift()! + traverse(nodeHandle) + } + + return { + parallelList, + hasAbnormalEdges, + } +} diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 0d924ba16c..b83d213cb8 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Overwrite and Import', importFailure: 'Import failure', importSuccess: 'Import success', + parallelRun: 'Parallel Run', + parallelTip: { + click: { + title: 'Click', + desc: ' to add', + }, + drag: { + title: 'Drag', + desc: ' to connect', + }, + limit: 'Parallelism is limited to {{num}} branches.', + depthLimit: 'Parallel nesting layer limit of {{num}} layers', + }, + disconnect: 'Disconnect', + jumpToNode: 'Jump to this node', + addParallelNode: 'Add Parallel Node', }, env: { envPanelTitle: 'Environment Variables', @@ -412,7 +428,6 @@ const translation = { 'not empty': 'is not empty', 'null': 'is null', 'not null': 'is not null', - 'regex match': 'regex match', }, enterValue: 'Enter value', addCondition: 'Add Condition', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 0f00b117c1..39311649f4 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: '覆盖并导入', importFailure: '导入失败', importSuccess: '导入成功', + parallelRun: '并行运行', + parallelTip: { + click: { + title: '点击', + desc: '添加节点', + }, + drag: { + title: '拖拽', + desc: '连接节点', + }, + limit: '并行分支限制为 {{num}} 个', + depthLimit: '并行嵌套层数限制 {{num}} 层', + }, + disconnect: '断开连接', + jumpToNode: '跳转到节点', + addParallelNode: '添加并行节点', }, env: { envPanelTitle: '环境变量', @@ -412,7 +428,6 @@ const translation = { 'not empty': '不为空', 'null': '空', 'not null': '不为空', - 'regex match': '正则匹配', }, enterValue: '输入值', addCondition: '添加条件', diff --git a/web/package.json b/web/package.json index b6275f20ae..374286f8f7 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.7.3", + "version": "0.8.0", "private": true, "engines": { "node": ">=18.17.0" diff --git a/web/service/base.ts b/web/service/base.ts index 8f1b22bc1b..83389d8be8 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -8,6 +8,8 @@ import type { IterationStartedResponse, NodeFinishedResponse, NodeStartedResponse, + ParallelBranchFinishedResponse, + ParallelBranchStartedResponse, TextChunkResponse, TextReplaceResponse, WorkflowFinishedResponse, @@ -59,6 +61,8 @@ export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void +export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void +export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void export type IOnTextChunk = (textChunk: TextChunkResponse) => void export type IOnTTSChunk = (messageId: string, audioStr: string, audioType?: string) => void export type IOnTTSEnd = (messageId: string, audioStr: string, audioType?: string) => void @@ -86,6 +90,8 @@ export type IOtherOptions = { onIterationStart?: IOnIterationStarted onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onParallelBranchStarted?: IOnParallelBranchStarted + onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk onTTSChunk?: IOnTTSChunk onTTSEnd?: IOnTTSEnd @@ -139,6 +145,8 @@ const handleStream = ( onIterationStart?: IOnIterationStarted, onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onParallelBranchStarted?: IOnParallelBranchStarted, + onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, onTTSChunk?: IOnTTSChunk, onTTSEnd?: IOnTTSEnd, @@ -228,6 +236,12 @@ const handleStream = ( else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'parallel_branch_started') { + onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) + } + else if (bufferObj.event === 'parallel_branch_finished') { + onParallelBranchFinished?.(bufferObj as ParallelBranchFinishedResponse) + } else if (bufferObj.event === 'text_chunk') { onTextChunk?.(bufferObj as TextChunkResponse) } @@ -488,6 +502,8 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onParallelBranchStarted, + onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, @@ -544,7 +560,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' && !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/themes/dark.css b/web/themes/dark.css index 8aab0f5fbb..8d77329b5a 100644 --- a/web/themes/dark.css +++ b/web/themes/dark.css @@ -316,6 +316,7 @@ html[data-theme="dark"] { --color-workflow-block-border: #FFFFFF14; --color-workflow-block-parma-bg: #FFFFFF0D; --color-workflow-block-bg: #27272B; + --color-workflow-block-border-highlight: #C8CEDA33; --color-workflow-canvas-workflow-dot-color: #8585AD26; --color-workflow-canvas-workflow-bg: #1D1D20; diff --git a/web/themes/light.css b/web/themes/light.css index 5a5ef63769..89303c250e 100644 --- a/web/themes/light.css +++ b/web/themes/light.css @@ -316,6 +316,7 @@ html[data-theme="light"] { --color-workflow-block-border: #FFFFFF; --color-workflow-block-parma-bg: #F2F4F7; --color-workflow-block-bg: #FCFCFD; + --color-workflow-block-border-highlight: #155AEF24; --color-workflow-canvas-workflow-dot-color: #8585AD26; --color-workflow-canvas-workflow-bg: #F2F4F7; diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts index 9178d23f24..643c96d1a1 100644 --- a/web/themes/tailwind-theme-var-define.ts +++ b/web/themes/tailwind-theme-var-define.ts @@ -316,6 +316,7 @@ const vars = { 'workflow-block-border': 'var(--color-workflow-block-border)', 'workflow-block-parma-bg': 'var(--color-workflow-block-parma-bg)', 'workflow-block-bg': 'var(--color-workflow-block-bg)', + 'workflow-block-border-highlight': 'var(--color-workflow-block-border-highlight)', 'workflow-canvas-workflow-dot-color': 'var(--color-workflow-canvas-workflow-dot-color)', 'workflow-canvas-workflow-bg': 'var(--color-workflow-canvas-workflow-bg)', diff --git a/web/types/workflow.ts b/web/types/workflow.ts index e6b001bd79..dbf2b3e587 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -26,9 +26,14 @@ export type NodeTracing = { currency: string iteration_id?: string iteration_index?: number + parallel_id?: string + parallel_start_node_id?: string + parent_parallel_id?: string + parent_parallel_start_node_id?: string } metadata: { iterator_length: number + iterator_index: number } created_at: number created_by: { @@ -40,6 +45,10 @@ export type NodeTracing = { extras?: any expand?: boolean // for UI details?: NodeTracing[][] // iteration detail + parallel_id?: string + parallel_start_node_id?: string + parent_parallel_id?: string + parent_parallel_start_node_id?: string } export type FetchWorkflowDraftResponse = { @@ -109,6 +118,7 @@ export type NodeStartedResponse = { data: { id: string node_id: string + iteration_id?: string node_type: string index: number predecessor_node_id?: string @@ -125,6 +135,7 @@ export type NodeFinishedResponse = { data: { id: string node_id: string + iteration_id?: string node_type: string index: number predecessor_node_id?: string @@ -138,6 +149,10 @@ export type NodeFinishedResponse = { total_tokens: number total_price: number currency: string + parallel_id?: string + parallel_start_node_id?: string + iteration_index?: number + iteration_id?: string } created_at: number } @@ -152,6 +167,8 @@ export type IterationStartedResponse = { node_id: string metadata: { iterator_length: number + iteration_id: string + iteration_index: number } created_at: number extras?: any @@ -169,6 +186,9 @@ export type IterationNextResponse = { output: any extras?: any created_at: number + execution_metadata: { + parallel_id?: string + } } } @@ -184,6 +204,39 @@ export type IterationFinishedResponse = { status: string created_at: number error: string + execution_metadata: { + parallel_id?: string + } + } +} + +export type ParallelBranchStartedResponse = { + task_id: string + workflow_run_id: string + event: string + data: { + parallel_id: string + parallel_start_node_id: string + parent_parallel_id: string + parent_parallel_start_node_id: string + iteration_id?: string + created_at: number + } +} + +export type ParallelBranchFinishedResponse = { + task_id: string + workflow_run_id: string + event: string + data: { + parallel_id: string + parallel_start_node_id: string + parent_parallel_id: string + parent_parallel_start_node_id: string + iteration_id?: string + status: string + created_at: number + error: string } }