From f68d6bd5e218056a2bb4a04c234e4935f695683a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 23:35:20 +0800 Subject: [PATCH] refactor(node.py): streamline template rendering Removed the `_render_basic_message` function and integrated its logic directly into the `LLMNode` class. This reduces redundancy and simplifies the handling of message templates by utilizing `convert_template` more directly. This change enhances code readability and maintainability. --- api/core/workflow/nodes/llm/node.py | 36 ++++++++--------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index a5620dbc01..d6e1019ce9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -36,7 +36,6 @@ from core.variables import ( FileSegment, NoneSegment, ObjectSegment, - SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID @@ -878,11 +877,11 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages.append(prompt_message) else: # Get segment group from basic message - segment_group = _render_basic_message( - template=message.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ) + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = self.graph_runtime_state.variable_pool.convert_template(template) # Process segments for images file_contents = [] @@ -926,11 +925,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable_pool=self.graph_runtime_state.variable_pool, ) else: - result_text = _render_basic_message( - template=template.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ).text + if context: + template = template.text.replace("{#context#}", context) + else: + template = template.text + result_text = self.graph_runtime_state.variable_pool.convert_template(template).text prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_messages.append(prompt_message) return prompt_messages @@ -967,18 +966,3 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text - - -def _render_basic_message( - *, - template: str, - context: str | None, - variable_pool: VariablePool, -) -> SegmentGroup: - if not template: - return SegmentGroup(value=[]) - - if context: - template = template.replace("{#context#}", context) - - return variable_pool.convert_template(template)