From d1505b15c40aa7bfe71bdc6f07efacc885985b34 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 5 Nov 2024 10:32:49 +0800 Subject: [PATCH] feat: Iteration node support parallel mode (#9493) --- .../advanced_chat/generate_task_pipeline.py | 3 +- .../apps/workflow/generate_task_pipeline.py | 3 +- api/core/app/apps/workflow_app_runner.py | 35 ++ api/core/app/entities/queue_entities.py | 37 +- api/core/app/entities/task_entities.py | 2 + .../task_pipeline/workflow_cycle_manage.py | 28 +- api/core/workflow/entities/node_entities.py | 1 + .../workflow/graph_engine/entities/event.py | 7 + .../workflow/graph_engine/graph_engine.py | 11 + api/core/workflow/nodes/iteration/entities.py | 10 + .../nodes/iteration/iteration_node.py | 412 ++++++++++++---- .../nodes/iteration/test_iteration.py | 449 +++++++++++++++++- web/app/components/base/select/index.tsx | 2 +- web/app/components/workflow/constants.ts | 4 +- .../workflow/hooks/use-nodes-interactions.ts | 5 + .../workflow/hooks/use-workflow-run.ts | 102 +++- .../workflow/nodes/_base/components/field.tsx | 6 +- .../components/workflow/nodes/_base/node.tsx | 24 +- .../workflow/nodes/iteration/default.ts | 39 +- .../workflow/nodes/iteration/node.tsx | 15 +- .../workflow/nodes/iteration/panel.tsx | 59 ++- .../workflow/nodes/iteration/types.ts | 5 + .../workflow/nodes/iteration/use-config.ts | 25 +- .../workflow/panel/debug-and-preview/hooks.ts | 12 +- web/app/components/workflow/run/index.tsx | 77 ++- .../workflow/run/iteration-result-panel.tsx | 20 +- web/app/components/workflow/run/node.tsx | 16 +- web/app/components/workflow/store.ts | 10 + web/app/components/workflow/types.ts | 6 +- web/app/components/workflow/utils.ts | 11 +- web/i18n/en-US/workflow.ts | 17 + web/i18n/zh-Hans/workflow.ts | 17 + web/types/workflow.ts | 5 + 33 files changed, 1283 insertions(+), 192 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e4cb3f8527..1fc7ffe2c7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -20,6 +20,7 @@ from core.app.entities.queue_entities import ( QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 419a5da806..d119d94a61 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index ca23bbdd47..9a01e8a253 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -9,6 +9,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner): node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner): error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, in_iteration_id=event.in_iteration_id, ) ) + elif isinstance(event, NodeInIterationFailedEvent): + self._publish_event( + QueueNodeInIterationFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) elif isinstance(event, NodeRunStreamChunkEvent): self._publish_event( QueueTextChunkEvent( @@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner): index=event.index, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, output=event.pre_iteration_output, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bc43baf8a5..f1542ec5d8 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel id if node is in parallel""" parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" - + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent): in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" start_at: datetime + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" class QueueNodeSucceededEvent(AppQueueEvent): @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): error: Optional[str] = None +class QueueNodeInIterationFailedEvent(AppQueueEvent): + """ + QueueNodeInIterationFailedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent): inputs: Optional[dict[str, Any]] = None process_data: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 4b5f4716ed..7e9aad54be 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse): parent_parallel_id: Optional[str] = None parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None + parallel_run_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse): extras: dict = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 2abee5bef5..b89edf9079 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData @@ -251,6 +253,12 @@ class WorkflowCycleManage: workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value workflow_node_execution.created_by_role = workflow_run.created_by_role workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) session.add(workflow_node_execution) @@ -305,7 +313,9 @@ class WorkflowCycleManage: return workflow_node_execution - def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed( + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event @@ -318,16 +328,19 @@ class WorkflowCycleManage: outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(timezone.utc).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, + WorkflowNodeExecution.execution_metadata: execution_metadata, } ) @@ -342,6 +355,7 @@ class WorkflowCycleManage: workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.execution_metadata = execution_metadata self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) @@ -448,6 +462,7 @@ class WorkflowCycleManage: parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, + parallel_run_id=event.parallel_mode_run_id, ), ) @@ -464,7 +479,7 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: @@ -608,6 +623,7 @@ class WorkflowCycleManage: extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parallel_mode_run_id=event.parallel_mode_run_id, ), ) @@ -633,7 +649,9 @@ class WorkflowCycleManage: created_at=int(time.time()), extras={}, inputs=event.inputs or {}, - status=WorkflowNodeExecutionStatus.SUCCEEDED, + status=WorkflowNodeExecutionStatus.SUCCEEDED + if event.error is None + else WorkflowNodeExecutionStatus.FAILED, error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 0131bb342b..7e10cddc71 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum): PARALLEL_START_NODE_ID = "parallel_start_node_id" PARENT_PARALLEL_ID = "parent_parallel_id" PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" + PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" class NodeRunResult(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 86d89e0a32..bacea191dd 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent): class NodeRunStartedEvent(BaseNodeEvent): predecessor_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None """predecessor node id""" @@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeInIterationFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + ########################################### # Parallel Branch Events ########################################### @@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent): """parent parallel id if node is in parallel""" parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" class IterationRunStartedEvent(BaseIterationEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 8f58af00ef..f07ad4de11 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -4,6 +4,7 @@ import time import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait +from copy import copy, deepcopy from typing import Any, Optional from flask import Flask, current_app @@ -724,6 +725,16 @@ class GraphEngine: """ return time.perf_counter() - start_at > max_execution_time + def create_copy(self): + """ + create a graph engine copy + :return: with a new variable pool instance of graph engine + """ + new_instance = copy(self) + new_instance.graph_runtime_state = copy(self.graph_runtime_state) + new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) + return new_instance + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 4afc870e50..ebcb6f82fb 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Optional from pydantic import Field @@ -5,6 +6,12 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +class ErrorHandleMode(str, Enum): + TERMINATED = "terminated" + CONTINUE_ON_ERROR = "continue-on-error" + REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" + + class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. @@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData): parent_loop_id: Optional[str] = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector + is_parallel: bool = False # open the parallel mode or not + parallel_nums: int = 10 # the numbers of parallel + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error class IterationStartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index af79da9215..d121b0530a 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,12 +1,20 @@ import logging +import uuid from collections.abc import Generator, Mapping, Sequence +from concurrent.futures import Future, wait from datetime import datetime, timezone -from typing import Any, cast +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any, Optional, cast + +from flask import Flask, current_app from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.variables import IntegerSegment -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeRunResult, +) +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) @@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.workflow.graph_engine.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "type": "iteration", + "config": { + "is_parallel": False, + "parallel_nums": 10, + "error_handle_mode": ErrorHandleMode.TERMINATED.value, + }, + } + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. @@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine - from core.workflow.graph_engine.graph_engine import GraphEngine + from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool graph_engine = GraphEngine( tenant_id=self.tenant_id, @@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]): index=0, pre_iteration_output=None, ) - outputs: list[Any] = [] try: - for _ in range(len(iterator_list_value)): - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id - - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): + if self.node_data.is_parallel: + futures: list[Future] = [] + q = Queue() + thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + for index, item in enumerate(iterator_list_value): + future: Future = thread_pool.submit( + self._run_single_iter_parallel, + current_app._get_current_object(), + q, + iterator_list_value, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, + index, + item, + ) + future.add_done_callback(thread_pool.task_done_callback) + futures.append(future) + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + if isinstance(event, IterationRunNextEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + yield event + if isinstance(event, RunCompletedEvent): + q.put(None) + for f in futures: + if not f.done(): + f.cancel() + yield event + if isinstance(event, IterationRunFailedEvent): + q.put(None) + yield event + except Empty: continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(index_variable, IntegerSegment): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Invalid index variable type: {type(index_variable)}", - ) - ) - return - metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value - event.route_node_state.node_run_result.metadata = metadata - - yield event - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - ) - ) - return - else: - event = cast(InNodeEvent, event) - yield event - - # append to iteration output variable list - current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) - if current_iteration_output_variable is None: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Iteration output variable {self.node_data.output_selector} not found", - ) + # wait all threads + wait(futures) + else: + for _ in range(len(iterator_list_value)): + yield from self._run_single_iter( + iterator_list_value, + variable_pool, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, ) - return - current_iteration_output = current_iteration_output_variable.to_object() - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove([node_id]) - - # move to next iteration - current_index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(current_index_variable, IntegerSegment): - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = current_index_variable.value + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output), - ) - yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]): } return variable_mapping + + def _handle_event_metadata( + self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str + ) -> NodeRunStartedEvent | BaseNodeEvent: + """ + add iteration metadata to event. + """ + if not isinstance(event, BaseNodeEvent): + return event + if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + event.parallel_mode_run_id = parallel_mode_run_id + return event + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + if self.node_data.is_parallel: + metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id + else: + metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + event.route_node_state.node_run_result.metadata = metadata + return event + + def _run_single_iter( + self, + iterator_list_value: list[str], + variable_pool: VariablePool, + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + parallel_mode_run_id: Optional[str] = None, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration + """ + try: + rst = graph_engine.run() + # get current iteration index + current_index = variable_pool.get([self.node_id, "index"]).value + next_index = int(current_index) + 1 + + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue + + if isinstance(event, NodeRunSucceededEvent): + yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + if isinstance(event, NodeRunFailedEvent): + if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + outputs.insert(current_index, None) + variable_pool.add([self.node_id, "index"], next_index) + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield metadata_event + + current_iteration_output = variable_pool.get(self.node_data.output_selector).value + outputs.insert(current_index, current_iteration_output) + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # move to next iteration + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + ) + + except Exception as e: + logger.exception(f"Iteration run failed:{str(e)}") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + + def _run_single_iter_parallel( + self, + flask_app: Flask, + q: Queue, + iterator_list_value: list[str], + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + index: int, + item: Any, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration in parallel mode + """ + with flask_app.app_context(): + parallel_mode_run_id = uuid.uuid4().hex + graph_engine_copy = graph_engine.create_copy() + variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool + variable_pool_copy.add([self.node_id, "index"], index) + variable_pool_copy.add([self.node_id, "item"], item) + for event in self._run_single_iter( + iterator_list_value=iterator_list_value, + variable_pool=variable_pool_copy, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine_copy, + iteration_graph=iteration_graph, + parallel_mode_run_id=parallel_mode_run_id, + ): + q.put(event) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index d755faee8a..29bd4d6c6c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.enums import UserFrom @@ -185,8 +186,6 @@ def test_run(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() @@ -404,18 +403,458 @@ def test_run_parallel(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() count = 0 for item in result: - # print(type(item), item) count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert count == 32 + + +def test_iteration_run_in_parallel_mode(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + parallel_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + sequential_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + parallel_result = parallel_iteration_node._run() + sequential_result = sequential_iteration_node._run() + assert parallel_iteration_node.node_data.parallel_nums == 10 + assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + count = 0 + parallel_arr = [] + sequential_arr = [] + for item in parallel_result: + count += 1 + parallel_arr.append(item) + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 32 + + for item in sequential_result: + sequential_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 64 + + +def test_iteration_run_error_handle(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "iteration-start", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "tt", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "tt2", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt2", "output"], + "output_type": "array[string]", + "start_node_id": "if-else", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1.split(arg2) }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, + ], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + ], + }, + "id": "tt2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "1", + "variable_selector": ["iteration-1", "item"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["1", "1"]) + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + }, + ) + # execute continue on error node + result = iteration_node._run() + result_arr = [] + count = 0 + for item in result: + result_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": [None, None]} + + assert count == 14 + # execute remove abnormal output + iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + result = iteration_node._run() + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": []} + assert count == 14 diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index ee5cee977b..c70cf24661 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -125,7 +125,7 @@ const Select: FC = ({ - {filteredItems.length > 0 && ( + {(filteredItems.length > 0 && open) && ( {filteredItems.map((item: Item) => ( { newNode.data.isInIteration = true newNode.data.iteration_id = prevNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX + if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) { + const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId) + const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data + iterNodeData._isShowTips = true + } } const newEdge: Edge = { diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 0bbb1adab8..26654ef71e 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -14,6 +14,7 @@ import { NodeRunningStatus, WorkflowRunningStatus, } from '../types' +import { DEFAULT_ITER_TIMES } from '../constants' import { useWorkflowUpdate } from './use-workflow-interactions' import { useStore as useAppStore } from '@/app/components/app/store' import type { IOtherOptions } from '@/service/base' @@ -170,11 +171,13 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterParallelLogMap, } = workflowStore.getState() const { edges, setEdges, } = store.getState() + setIterParallelLogMap(new Map()) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.task_id = task_id draft.result = { @@ -244,6 +247,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -259,10 +264,21 @@ export const useWorkflowRun = () => { const tracing = draft.tracing! const iterations = tracing.find(trace => trace.node_id === node?.parentId) const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] - currIteration?.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) + if (!data.parallel_run_id) { + currIteration?.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + } + else { + if (!iterParallelLogMap.has(data.parallel_run_id)) + iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) + else + iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) + setIterParallelLogMap(iterParallelLogMap) + if (iterations) + iterations.details = Array.from(iterParallelLogMap.values()) + } })) } else { @@ -309,6 +325,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -317,21 +335,21 @@ export const useWorkflowRun = () => { const nodes = getNodes() const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId if (nodeParentId) { - setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const tracing = draft.tracing! - const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + if (!data.execution_metadata.parallel_mode_run_id) { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - if (iterations && iterations.details) { - const iterationIndex = data.execution_metadata?.iteration_index || 0 - if (!iterations.details[iterationIndex]) - iterations.details[iterationIndex] = [] + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] - const currIteration = iterations.details[iterationIndex] - const nodeIndex = currIteration.findIndex(node => - node.node_id === data.node_id && ( - node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), - ) - if (data.status === NodeRunningStatus.Succeeded) { + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) if (nodeIndex !== -1) { currIteration[nodeIndex] = { ...currIteration[nodeIndex], @@ -344,8 +362,40 @@ export const useWorkflowRun = () => { } as any) } } - } - })) + })) + } + else { + // open parallel mode + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + + if (iterations && iterations.details) { + const iterRunID = data.execution_metadata?.parallel_mode_run_id + + const currIteration = iterParallelLogMap.get(iterRunID) + const nodeIndex = currIteration?.findIndex(node => + node.node_id === data.node_id && ( + node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), + ) + if (currIteration) { + if (nodeIndex !== undefined && nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + setIterParallelLogMap(iterParallelLogMap) + iterations.details = Array.from(iterParallelLogMap.values()) + } + })) + } } else { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { @@ -379,6 +429,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -388,6 +439,7 @@ export const useWorkflowRun = () => { transform, } = store.getState() const nodes = getNodes() + setIterTimes(DEFAULT_ITER_TIMES) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -431,6 +483,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterTimes, + setIterTimes, } = workflowStore.getState() const { data } = params @@ -445,13 +499,14 @@ export const useWorkflowRun = () => { if (iteration.details!.length >= iteration.metadata.iterator_length!) return } - iteration?.details!.push([]) + if (!data.parallel_mode_run_id) + iteration?.details!.push([]) })) const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - - currentNode.data._iterationIndex = data.index > 0 ? data.index : 1 + currentNode.data._iterationIndex = iterTimes + setIterTimes(iterTimes + 1) }) setNodes(newNodes) @@ -464,6 +519,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -480,7 +536,7 @@ export const useWorkflowRun = () => { }) } })) - + setIterTimes(DEFAULT_ITER_TIMES) const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index 6459cf8056..b2f815a325 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string title: JSX.Element | string | DefaultTFuncReturn + tooltip?: React.ReactNode isSubTitle?: boolean - tooltip?: string supportFold?: boolean children?: JSX.Element | string | null operations?: JSX.Element inline?: boolean } -const Filed: FC = ({ +const Field: FC = ({ className, title, isSubTitle, @@ -60,4 +60,4 @@ const Filed: FC = ({ ) } -export default React.memo(Filed) +export default React.memo(Field) diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index bd5921c735..e864c419e2 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,6 +25,7 @@ import { useToolIcon, } from '../../hooks' import { useNodeIterationInteractions } from '../iteration/use-interactions' +import type { IterationNodeType } from '../iteration/types' import { NodeSourceHandle, NodeTargetHandle, @@ -34,6 +35,7 @@ import NodeControl from './components/node-control' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' +import Tooltip from '@/app/components/base/tooltip' type BaseNodeProps = { children: ReactElement @@ -166,9 +168,27 @@ const BaseNode: FC = ({ />
- {data.title} +
+ {data.title} +
+ { + data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && ( + +
+ {t('workflow.nodes.iteration.parallelModeEnableTitle')} +
+ {t('workflow.nodes.iteration.parallelModeEnableDesc')} +
} + > +
+ {t('workflow.nodes.iteration.parallelModeUpper')} +
+ + ) + } { data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && ( diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 3afa52d06e..cdef268adb 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -1,7 +1,10 @@ -import { BlockEnum } from '../../types' +import { BlockEnum, ErrorHandleMode } from '../../types' import type { NodeDefault } from '../../types' import type { IterationNodeType } from './types' -import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' +import { + ALL_CHAT_AVAILABLE_BLOCKS, + ALL_COMPLETION_AVAILABLE_BLOCKS, +} from '@/app/components/workflow/constants' const i18nPrefix = 'workflow' const nodeDefault: NodeDefault = { @@ -10,25 +13,45 @@ const nodeDefault: NodeDefault = { iterator_selector: [], output_selector: [], _children: [], + _isShowTips: false, + is_parallel: false, + parallel_nums: 10, + error_handle_mode: ErrorHandleMode.Terminated, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS - : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter( + type => type !== BlockEnum.End, + ) return nodes }, getAvailableNextNodes(isChatMode: boolean) { - const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + const nodes = isChatMode + ? ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS return nodes }, checkValid(payload: IterationNodeType, t: any) { let errorMessages = '' - if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) }) + if ( + !errorMessages + && (!payload.iterator_selector || payload.iterator_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.input`), + }) + } - if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) }) + if ( + !errorMessages + && (!payload.output_selector || payload.output_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.output`), + }) + } return { isValid: !errorMessages, diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index 48a005a261..fda033b87a 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,12 +8,16 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' import cn from '@/utils/classnames' import type { NodeProps } from '@/app/components/workflow/types' +import Toast from '@/app/components/base/toast' + +const i18nPrefix = 'workflow.nodes.iteration' const Node: FC> = ({ id, @@ -22,11 +26,20 @@ const Node: FC> = ({ const { zoom } = useViewport() const nodesInitialized = useNodesInitialized() const { handleNodeIterationRerender } = useNodeIterationInteractions() + const { t } = useTranslation() useEffect(() => { if (nodesInitialized) handleNodeIterationRerender(id) - }, [nodesInitialized, id, handleNodeIterationRerender]) + if (data.is_parallel && data._isShowTips) { + Toast.notify({ + type: 'warning', + message: t(`${i18nPrefix}.answerNodeWarningDesc`), + duration: 5000, + }) + data._isShowTips = false + } + }, [nodesInitialized, id, handleNodeIterationRerender, data, t]) return (
> = ({ data, }) => { const { t } = useTranslation() - + const responseMethod = [ + { + value: ErrorHandleMode.Terminated, + name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`), + }, + { + value: ErrorHandleMode.ContinueOnError, + name: t(`${i18nPrefix}.ErrorMethod.continueOnError`), + }, + { + value: ErrorHandleMode.RemoveAbnormalOutput, + name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`), + }, + ] const { readOnly, inputs, @@ -47,6 +66,9 @@ const Panel: FC> = ({ setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } = useConfig(id, data) return ( @@ -87,6 +109,39 @@ const Panel: FC> = ({ />
+
+ {t(`${i18nPrefix}.parallelPanelDesc`)}
} inline> + + + + { + inputs.is_parallel && (
+ {t(`${i18nPrefix}.MaxParallelismDesc`)}
}> +
+ { changeParallelNums(Number(e.target.value)) }} /> + +
+ + + ) + } +
+ +
+ +
+ + + +
+ {isShowSingleRun && ( { @@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => { }) }, [iteratorInputKey, runInputData, setRunInputData]) + const changeParallel = useCallback((value: boolean) => { + const newInputs = produce(inputs, (draft) => { + draft.is_parallel = value + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + const changeErrorResponseMode = useCallback((item: Item) => { + const newInputs = produce(inputs, (draft) => { + draft.error_handle_mode = item.value as ErrorHandleMode + }) + setInputs(newInputs) + }, [inputs, setInputs]) + const changeParallelNums = useCallback((num: number) => { + const newInputs = produce(inputs, (draft) => { + draft.parallel_nums = num + }) + setInputs(newInputs) + }, [inputs, setInputs]) return { readOnly, inputs, @@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => { setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } } diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 58a4561e2c..5d932a1ba2 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer' import { uniqBy } from 'lodash-es' import { useWorkflowRun } from '../../hooks' import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' +import { useWorkflowStore } from '../../store' +import { DEFAULT_ITER_TIMES } from '../../constants' import type { ChatItem, Inputs, @@ -43,6 +45,7 @@ export const useChat = ( const { notify } = useToastContext() const { handleRun } = useWorkflowRun() const hasStopResponded = useRef(false) + const workflowStore = useWorkflowStore() const conversationId = useRef('') const taskIdRef = useRef('') const [chatList, setChatList] = useState(prevChatList || []) @@ -52,6 +55,9 @@ export const useChat = ( const [suggestedQuestions, setSuggestQuestions] = useState([]) const suggestedQuestionsAbortControllerRef = useRef(null) + const { + setIterTimes, + } = workflowStore.getState() useEffect(() => { setAutoFreeze(false) return () => { @@ -102,15 +108,16 @@ export const useChat = ( handleResponding(false) if (stopChat && taskIdRef.current) stopChat(taskIdRef.current) - + setIterTimes(DEFAULT_ITER_TIMES) if (suggestedQuestionsAbortControllerRef.current) suggestedQuestionsAbortControllerRef.current.abort() - }, [handleResponding, stopChat]) + }, [handleResponding, setIterTimes, stopChat]) const handleRestart = useCallback(() => { conversationId.current = '' taskIdRef.current = '' handleStop() + setIterTimes(DEFAULT_ITER_TIMES) const newChatList = config?.opening_statement ? [{ id: `${Date.now()}`, @@ -126,6 +133,7 @@ export const useChat = ( config, handleStop, handleUpdateChatList, + setIterTimes, ]) const updateCurrentQA = useCallback(({ diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 9e636e902b..89db43fa35 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -60,36 +60,67 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe }, [notify, getResultCallback]) const formatNodeList = useCallback((list: NodeTracing[]) => { - const allItems = list.reverse() + const allItems = [...list].reverse() const result: NodeTracing[] = [] - allItems.forEach((item) => { - const { node_type, execution_metadata } = item - if (node_type !== BlockEnum.Iteration) { - const isInIteration = !!execution_metadata?.iteration_id + const groupMap = new Map() - if (isInIteration) { - const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) - const iterationDetails = iterationNode?.details - const currentIterationIndex = execution_metadata?.iteration_index ?? 0 - - if (Array.isArray(iterationDetails)) { - if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) - iterationDetails[currentIterationIndex] = [item] - else - iterationDetails[currentIterationIndex].push(item) - } - return - } - // not in iteration - result.push(item) - - return - } + const processIterationNode = (item: NodeTracing) => { result.push({ ...item, details: [], }) + } + const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => { + if (!groupMap.has(runId)) + groupMap.set(runId, [item]) + else + groupMap.get(runId)!.push(item) + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + + iterationNode.details = Array.from(groupMap.values()) + } + const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { + const { details } = iterationNode + if (details) { + if (!details[index]) + details[index] = [item] + else + details[index].push(item) + } + + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + } + const processNonIterationNode = (item: NodeTracing) => { + const { execution_metadata } = item + if (!execution_metadata?.iteration_id) { + result.push(item) + return + } + + const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id) + if (!iterationNode || !Array.isArray(iterationNode.details)) + return + + const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata + + if (parallel_mode_run_id) + updateParallelModeGroup(parallel_mode_run_id, item, iterationNode) + else + updateSequentialModeGroup(iteration_index, item, iterationNode) + } + + allItems.forEach((item) => { + item.node_type === BlockEnum.Iteration + ? processIterationNode(item) + : processNonIterationNode(item) }) + return result }, []) diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index 7e2f6cbc00..c4cd909f2e 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightSLine, RiCloseLine, + RiErrorWarningLine, } from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' import TracingPanel from './tracing-panel' @@ -27,7 +28,7 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() - const [expandedIterations, setExpandedIterations] = useState>([]) + const [expandedIterations, setExpandedIterations] = useState>({}) const toggleIteration = useCallback((index: number) => { setExpandedIterations(prev => ({ @@ -71,10 +72,19 @@ const IterationResultPanel: FC = ({ {t(`${i18nPrefix}.iteration`)} {index + 1} - + { + iteration.some(item => item.status === 'failed') + ? ( + + ) + : (< RiArrowRightSLine className={ + cn( + 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', + expandedIterations[index] && 'transform rotate-90', + )} /> + ) + } + {expandedIterations[index] &&
= ({ return iteration_length } + const getErrorCount = (details: NodeTracing[][] | undefined) => { + if (!details || details.length === 0) + return 0 + return details.reduce((acc, iteration) => { + if (iteration.some(item => item.status === 'failed')) + acc++ + return acc + }, 0) + } useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) @@ -136,7 +145,12 @@ const NodePanel: FC = ({ onClick={handleOnShowIterationDetail} > -
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}
+
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && ( + <> + {t('workflow.nodes.iteration.comma')} + {t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })} + + )}
{justShowIterationNavArrow ? ( diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index c2a6823e6b..c4a625c777 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -21,6 +21,7 @@ import type { WorkflowRunningData, } from './types' import { WorkflowContext } from './context' +import type { NodeTracing } from '@/types/workflow' // #TODO chatVar# // const MOCK_DATA = [ @@ -166,6 +167,10 @@ type Shape = { setShowImportDSLModal: (showImportDSLModal: boolean) => void showTips: string setShowTips: (showTips: string) => void + iterTimes: number + setIterTimes: (iterTimes: number) => void + iterParallelLogMap: Map + setIterParallelLogMap: (iterParallelLogMap: Map) => void } export const createWorkflowStore = () => { @@ -281,6 +286,11 @@ export const createWorkflowStore = () => { setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), showTips: '', setShowTips: showTips => set(() => ({ showTips })), + iterTimes: 1, + setIterTimes: iterTimes => set(() => ({ iterTimes })), + iterParallelLogMap: new Map(), + setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })), + })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 81bec41eac..9b6ad033bf 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -36,7 +36,11 @@ export enum ControlMode { Pointer = 'pointer', Hand = 'hand', } - +export enum ErrorHandleMode { + Terminated = 'terminated', + ContinueOnError = 'continue-on-error', + RemoveAbnormalOutput = 'remove-abnormal-output', +} export type Branch = { id: string name: string diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 91656e3bbc..aaf333f4d7 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -19,7 +19,7 @@ import type { ToolWithProvider, ValueSelector, } from './types' -import { BlockEnum } from './types' +import { BlockEnum, ErrorHandleMode } from './types' import { CUSTOM_NODE, ITERATION_CHILDREN_Z_INDEX, @@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { }) } - if (node.data.type === BlockEnum.Iteration) - node.data._children = iterationNodeMap[node.id] || [] + if (node.data.type === BlockEnum.Iteration) { + const iterationNodeData = node.data as IterationNodeType + iterationNodeData._children = iterationNodeMap[node.id] || [] + iterationNodeData.is_parallel = iterationNodeData.is_parallel || false + iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10 + iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated + } return node }) diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index ea8355500a..1c6639aba0 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}} Iteration', iteration_other: '{{count}} Iterations', currentIteration: 'Current Iteration', + comma: ', ', + error_one: '{{count}} Error', + error_other: '{{count}} Errors', + parallelMode: 'Parallel Mode', + parallelModeUpper: 'PARALLEL MODE', + parallelModeEnableTitle: 'Parallel Mode Enabled', + parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. You can configure this in the properties panel on the right.', + parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.', + MaxParallelismTitle: 'Maximum parallelism', + MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.', + errorResponseMethod: 'Error response method', + ErrorMethod: { + operationTerminated: 'terminated', + continueOnError: 'continue-on-error', + removeAbnormalOutput: 'remove-abnormal-output', + }, + answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.', }, note: { addNote: 'Add Note', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 515d0fe235..1229ba8c03 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}}个迭代', iteration_other: '{{count}}个迭代', currentIteration: '当前迭代', + comma: ',', + error_one: '{{count}}个失败', + error_other: '{{count}}个失败', + parallelMode: '并行模式', + parallelModeUpper: '并行模式', + parallelModeEnableTitle: '并行模式启用', + parallelModeEnableDesc: '启用并行模式时迭代内的任务支持并行执行。你可以在右侧的属性面板中进行配置。', + parallelPanelDesc: '在并行模式下,迭代中的任务支持并行执行。', + MaxParallelismTitle: '最大并行度', + MaxParallelismDesc: '最大并行度用于控制单次迭代中同时执行的任务数量。', + errorResponseMethod: '错误响应方法', + ErrorMethod: { + operationTerminated: '错误时终止', + continueOnError: '忽略错误并继续', + removeAbnormalOutput: '移除错误输出', + }, + answerNodeWarningDesc: '并行模式警告:在迭代中,回答节点、会话变量赋值和工具持久读/写操作可能会导致异常。', }, note: { addNote: '添加注释', diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 810026b084..3c0675b605 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -19,6 +19,7 @@ export type NodeTracing = { process_data: any outputs?: any status: string + parallel_run_id?: string error?: string elapsed_time: number execution_metadata: { @@ -31,6 +32,7 @@ export type NodeTracing = { parallel_start_node_id?: string parent_parallel_id?: string parent_parallel_start_node_id?: string + parallel_mode_run_id?: string } metadata: { iterator_length: number @@ -121,6 +123,7 @@ export type NodeStartedResponse = { id: string node_id: string iteration_id?: string + parallel_run_id?: string node_type: string index: number predecessor_node_id?: string @@ -166,6 +169,7 @@ export type NodeFinishedResponse = { parallel_start_node_id?: string iteration_index?: number iteration_id?: string + parallel_mode_run_id: string } created_at: number files?: FileResponse[] @@ -200,6 +204,7 @@ export type IterationNextResponse = { output: any extras?: any created_at: number + parallel_mode_run_id: string execution_metadata: { parallel_id?: string }