diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index a5620dbc01..d6e1019ce9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -36,7 +36,6 @@ from core.variables import ( FileSegment, NoneSegment, ObjectSegment, - SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID @@ -878,11 +877,11 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages.append(prompt_message) else: # Get segment group from basic message - segment_group = _render_basic_message( - template=message.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ) + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = self.graph_runtime_state.variable_pool.convert_template(template) # Process segments for images file_contents = [] @@ -926,11 +925,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable_pool=self.graph_runtime_state.variable_pool, ) else: - result_text = _render_basic_message( - template=template.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ).text + if context: + template = template.text.replace("{#context#}", context) + else: + template = template.text + result_text = self.graph_runtime_state.variable_pool.convert_template(template).text prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_messages.append(prompt_message) return prompt_messages @@ -967,18 +966,3 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text - - -def _render_basic_message( - *, - template: str, - context: str | None, - variable_pool: VariablePool, -) -> SegmentGroup: - if not template: - return SegmentGroup(value=[]) - - if context: - template = template.replace("{#context#}", context) - - return variable_pool.convert_template(template)