diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ada0b14ce4..8f58af00ef 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -130,15 +130,14 @@ class GraphEngine: yield GraphRunStartedEvent() try: - stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] if self.init_params.workflow_type == WorkflowType.CHAT: - stream_processor_cls = AnswerStreamProcessor + stream_processor = AnswerStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) else: - stream_processor_cls = EndStreamProcessor - - stream_processor = stream_processor_cls( - graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool - ) + stream_processor = EndStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) # run graph generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index bce28c5fcb..bc4b056148 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in { - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, }: answer_dependencies[answer_node_id].append(source_node_id) else: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index e3889941ca..8a768088da 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor): super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} - for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + for answer_node_id in self.generate_routes.answer_generate_route: self.route_position[answer_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 52d0358c76..36c3fe180a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -41,7 +41,6 @@ class StreamProcessor(ABC): continue else: unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index e543d02dd7..a05cc44c99 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from pydantic import BaseModel, Field @@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR """generate route chunk type""" - value_selector: list[str] = Field(..., description="value selector") + value_selector: Sequence[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk):