From 3b53e06e0d8f0792027beb1f10830e7c4529f15c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 30 Oct 2024 16:23:12 +0800 Subject: [PATCH] fix(workflow): refine variable type checks in LLMNode (#10051) --- api/core/workflow/nodes/llm/node.py | 8 +- .../core/workflow/nodes/llm/test_node.py | 125 ++++++++++++++++++ 2 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_node.py diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 472587cb03..b4728e6abf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is None: return [] - if isinstance(variable, FileSegment): + elif isinstance(variable, FileSegment): return [variable.value] - if isinstance(variable, ArrayFileSegment): + elif isinstance(variable, ArrayFileSegment): return variable.value - # FIXME: Temporary fix for empty array, - # all variables added to variable pool should be a Segment instance. - if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: + elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] raise ValueError(f"Invalid variable type: {type(variable)}") diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py new file mode 100644 index 0000000000..def6c2a232 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -0,0 +1,125 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class TestLLMNode: + @pytest.fixture + def llm_node(self): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + def test_fetch_files_with_file_segment(self, llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + def test_fetch_files_with_array_file_segment(self, llm_node): + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + def test_fetch_files_with_none_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_array_any_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_non_existent_variable(self, llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == []