feat(llm_node): allow to use image file directly in the prompt.

This commit is contained in:
-LAN- 2024-11-14 18:34:16 +08:00
parent 2a58cc59c0
commit 600d111e8f
2 changed files with 651 additions and 145 deletions

View File

@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
@ -30,10 +36,13 @@ from core.variables import (
FileSegment,
NoneSegment,
ObjectSegment,
SegmentGroup,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@ -62,14 +71,18 @@ from .exc import (
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
NoPromptFoundError,
NotSupportedPromptTypeError,
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.file.models import File
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
@ -131,9 +144,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = None
prompt_messages, stop = self._fetch_prompt_messages(
system_query=query,
inputs=inputs,
files=files,
user_query=query,
user_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -203,7 +215,7 @@ class LLMNode(BaseNode[LLMNodeData]):
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()
@ -519,9 +531,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
user_query: str | None = None,
user_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -529,60 +540,161 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = inputs or {}
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=system_query or "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
)
stop = model_config.stop
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context))
# Get memory messages for chat mode
memory_messages = self._handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)]))
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context))
# Get memory text for completion model
memory_text = self._handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
errmsg = f"Prompt type {type(prompt_template)} is not supported"
logger.warning(errmsg)
raise NotSupportedPromptTypeError(errmsg)
if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if prompt_message.is_empty():
continue
if not isinstance(prompt_message.content, str):
if isinstance(prompt_message.content, list):
prompt_message_content = []
for content_item in prompt_message.content or []:
for content_item in prompt_message.content:
# Skip image if vision is disabled
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
continue
if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
):
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
stop = model_config.stop
return filtered_prompt_messages, stop
def _handle_memory_chat_mode(
self,
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = self._calculate_rest_token([], model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
self,
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = self._calculate_rest_token([], model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
@ -715,3 +827,121 @@ class LLMNode(BaseNode[LLMNodeData]):
}
},
}
def _handle_list_messages(
self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str]
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
variable_pool=self.graph_runtime_state.variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
segment_group = _render_basic_message(
template=message.text,
context=context,
variable_pool=self.graph_runtime_state.variable_pool,
)
# Process segments for images
image_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type == FileType.IMAGE:
image_content = file_manager.to_prompt_message_content(
file, image_detail_config=self.node_data.vision.configs.detail
)
image_contents.append(image_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type == FileType.IMAGE:
image_content = file_manager.to_prompt_message_content(
file, image_detail_config=self.node_data.vision.configs.detail
)
image_contents.append(image_content)
# Create message with text from all segments
prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role)
prompt_messages.append(prompt_message)
if image_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=image_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _handle_completion_template(
self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str]
) -> Sequence[PromptMessage]:
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
variable_pool=self.graph_runtime_state.variable_pool,
)
else:
result_text = _render_basic_message(
template=template.text,
context=context,
variable_pool=self.graph_runtime_state.variable_pool,
).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _render_basic_message(
*,
template: str,
context: str | None,
variable_pool: VariablePool,
) -> SegmentGroup:
if not template:
return SegmentGroup(value=[])
if context:
template = template.replace("{#context#}", context)
return variable_pool.convert_template(template)

View File

@ -1,125 +1,401 @@
from collections.abc import Sequence
from typing import Optional
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
class TestLLMNode:
@pytest.fixture
def llm_node(self):
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
class MockTokenBufferMemory:
def __init__(self, history_messages=None):
self.history_messages = history_messages or []
def test_fetch_files_with_file_segment(self, llm_node):
file = File(
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]
return self.history_messages
@pytest.fixture
def llm_node():
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory()
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance.get_provider_schema(),
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
provider_instance=provider_instance,
model_type_instance=model_type_instance,
)
# Create and return a ModelConfigWithCredentialsEntity
return ModelConfigWithCredentialsEntity(
provider="openai",
model="gpt-3.5-turbo",
model_schema=AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(en_US="GPT-3.5 Turbo"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={},
),
mode="chat",
credentials={},
parameters={},
provider_model_bundle=provider_model_bundle,
)
def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
prompt_template = []
llm_node.node_data.prompt_template = prompt_template
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
]
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
fake_query = faker.sentence()
def test_fetch_files_with_array_file_segment(self, llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
context=None,
memory=None,
model_config=model_config,
prompt_template=prompt_template,
memory_config=None,
vision_enabled=False,
vision_detail=fake_vision_detail,
)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
assert prompt_messages == [UserPromptMessage(content=fake_query)]
def test_fetch_files_with_none_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
def test_fetch_files_with_array_any_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
# Generate fake values for prompt template
fake_user_prompt = faker.sentence()
fake_assistant_prompt = faker.sentence()
fake_query = faker.sentence()
random_context = faker.sentence()
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
# Generate fake values for vision
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
fake_prompt_image_url = faker.url()
def test_fetch_files_with_non_existent_variable(self, llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
# Setup prompt template with image variable reference
prompt_template = [
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{{#input.images#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
]
llm_node.node_data.prompt_template = prompt_template
# Setup vision files
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
related_id="1",
)
]
# Setup prompt image in variable pool
prompt_image = File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="prompt_image.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_prompt_image_url,
related_id="2",
)
prompt_images = [
File(
id="3",
tenant_id="test",
type=FileType.IMAGE,
filename="prompt_image.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_prompt_image_url,
related_id="3",
),
File(
id="4",
tenant_id="test",
type=FileType.IMAGE,
filename="prompt_image.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_prompt_image_url,
related_id="4",
),
]
llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image)
llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images)
# Setup memory configuration with random window size
window_size = faker.random_int(min=1, max=3)
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=True, size=window_size),
query_prompt_template=None,
)
# Setup mock memory with history messages
mock_history = [
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
]
memory = MockTokenBufferMemory(history_messages=mock_history)
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
context=random_context,
memory=memory,
model_config=model_config,
prompt_template=prompt_template,
memory_config=memory_config,
vision_enabled=True,
vision_detail=fake_vision_detail,
)
# Build expected messages
expected_messages = [
# Base template messages
SystemPromptMessage(content=random_context),
# Image from variable pool in prompt template
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
]
),
AssistantPromptMessage(content=fake_assistant_prompt),
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
]
),
]
# Add memory messages based on window size
expected_messages.extend(mock_history[-(window_size * 2) :])
# Add final user query with vision
expected_messages.append(
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
)
)
# Verify the result
assert prompt_messages == expected_messages