refactor(core): decouple LLMNode prompt handling

Moved prompt handling functions out of the `LLMNode` class to improve modularity and separation of concerns. This refactor allows better reuse and testing of prompt-related functions. Adjusted existing logic to fetch queries and handle context and memory configurations more effectively. Updated tests to align with the new structure and ensure continued functionality.
This commit is contained in:
-LAN- 2024-11-15 00:18:36 +08:00
parent f68d6bd5e2
commit 4e360ec19a
3 changed files with 210 additions and 153 deletions

View File

@ -38,7 +38,6 @@ from core.variables import (
ObjectSegment, ObjectSegment,
StringSegment, StringSegment,
) )
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -135,10 +134,7 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch prompt messages # fetch prompt messages
if self.node_data.memory: if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query = self.node_data.memory.query_prompt_template
if not query:
raise VariableNotFoundError("Query not found")
query = query.text
else: else:
query = None query = None
@ -152,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory, memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
) )
process_data = { process_data = {
@ -550,15 +548,25 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None, memory_config: MemoryConfig | None = None,
vision_enabled: bool = False, vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = [] prompt_messages = []
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
# For chat model # For chat model
prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context)) prompt_messages.extend(
_handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
# Get memory messages for chat mode # Get memory messages for chat mode
memory_messages = self._handle_memory_chat_mode( memory_messages = _handle_memory_chat_mode(
memory=memory, memory=memory,
memory_config=memory_config, memory_config=memory_config,
model_config=model_config, model_config=model_config,
@ -568,14 +576,34 @@ class LLMNode(BaseNode[LLMNodeData]):
# Add current query to the prompt messages # Add current query to the prompt messages
if user_query: if user_query:
prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)])) message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model # For completion model
prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context)) prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model # Get memory text for completion model
memory_text = self._handle_memory_completion_mode( memory_text = _handle_memory_completion_mode(
memory=memory, memory=memory,
memory_config=memory_config, memory_config=memory_config,
model_config=model_config, model_config=model_config,
@ -628,7 +656,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if ( if (
( (
content_item.type == PromptMessageContentType.IMAGE content_item.type == PromptMessageContentType.IMAGE
and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features) and ModelFeature.VISION not in model_config.model_schema.features
) )
or ( or (
content_item.type == PromptMessageContentType.DOCUMENT content_item.type == PromptMessageContentType.DOCUMENT
@ -662,73 +690,6 @@ class LLMNode(BaseNode[LLMNodeData]):
stop = model_config.stop stop = model_config.stop
return filtered_prompt_messages, stop return filtered_prompt_messages, stop
def _handle_memory_chat_mode(
self,
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = self._calculate_rest_token([], model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
self,
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = self._calculate_rest_token([], model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
@classmethod @classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle provider_model_bundle = model_instance.provider_model_bundle
@ -862,78 +823,6 @@ class LLMNode(BaseNode[LLMNodeData]):
}, },
} }
def _handle_list_messages(
self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str]
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
variable_pool=self.graph_runtime_state.variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = self.graph_runtime_state.variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=self.node_data.vision.configs.detail
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=self.node_data.vision.configs.detail
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _handle_completion_template(
self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str]
) -> Sequence[PromptMessage]:
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
variable_pool=self.graph_runtime_state.variable_pool,
)
else:
if context:
template = template.text.replace("{#context#}", context)
else:
template = template.text
result_text = self.graph_runtime_state.variable_pool.convert_template(template).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role: match role:
@ -966,3 +855,165 @@ def _render_jinja2_message(
) )
result_text = code_execute_resp["result"] result_text = code_execute_resp["result"]
return result_text return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
) )
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
system_query=query, user_query=query,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
files=files, user_files=files,
vision_enabled=node_data.vision.enabled, vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
) )
# handle invoke result # handle invoke result

View File

@ -240,6 +240,8 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
memory_config=None, memory_config=None,
vision_enabled=False, vision_enabled=False,
vision_detail=fake_vision_detail, vision_detail=fake_vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
) )
assert prompt_messages == [UserPromptMessage(content=fake_query)] assert prompt_messages == [UserPromptMessage(content=fake_query)]
@ -368,7 +370,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
description="Prompt template with variable selector of File", description="Prompt template with variable selector of File",
user_query=fake_query, user_query=fake_query,
user_files=[], user_files=[],
vision_enabled=True, vision_enabled=False,
vision_detail=fake_vision_detail, vision_detail=fake_vision_detail,
features=[ModelFeature.VISION], features=[ModelFeature.VISION],
window_size=fake_window_size, window_size=fake_window_size,
@ -471,6 +473,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
memory_config=memory_config, memory_config=memory_config,
vision_enabled=scenario.vision_enabled, vision_enabled=scenario.vision_enabled,
vision_detail=scenario.vision_detail, vision_detail=scenario.vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
) )
# Verify the result # Verify the result