feat(llm_node): allow using image files directly in the prompt.

-LAN- 2024-11-14 18:34:16 +08:00
parent bab989e3b3
commit d6c9ab8554
2 changed files with 651 additions and 145 deletions
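The core of the change: a chat prompt template message can now reference a file variable directly, and the LLM node resolves it from the variable pool into image prompt content instead of relying only on the separate vision file input. A minimal sketch of the intended usage, modelled on the tests added in this commit; the llm_node and model_config fixtures, the input.image selector, and the file values are illustrative assumptions:

    # A file stored in the variable pool under ["input", "image"] (illustrative values).
    image = File(
        id="1",
        tenant_id="test",
        type=FileType.IMAGE,
        filename="photo.jpg",
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/photo.jpg",
        related_id="1",
    )
    llm_node.graph_runtime_state.variable_pool.add(["input", "image"], image)

    # The user message references the file variable directly in its template text.
    prompt_template = [
        LLMNodeChatModelMessage(text="{{#input.image#}}", role=PromptMessageRole.USER, edition_type="basic"),
    ]
    llm_node.node_data.prompt_template = prompt_template

    prompt_messages, _ = llm_node._fetch_prompt_messages(
        user_query="What is in this picture?",
        user_files=[],
        context=None,
        memory=None,
        model_config=model_config,
        prompt_template=prompt_template,
        memory_config=None,
        vision_enabled=True,
        vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
    )
    # prompt_messages now contains ImagePromptMessageContent built from the referenced file.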

View File

@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast
@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.file import FileType, file_manager
+from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
-    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
-    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
@@ -30,10 +36,13 @@ from core.variables import (
     FileSegment,
     NoneSegment,
     ObjectSegment,
+    SegmentGroup,
     StringSegment,
 )
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode
@@ -62,14 +71,18 @@ from .exc import (
     InvalidVariableTypeError,
     LLMModeRequiredError,
     LLMNodeError,
+    MemoryRolePrefixRequiredError,
     ModelNotExistError,
     NoPromptFoundError,
+    NotSupportedPromptTypeError,
     VariableNotFoundError,
 )
 
 if TYPE_CHECKING:
     from core.file.models import File
 
+logger = logging.getLogger(__name__)
+
 
 class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
@@ -131,9 +144,8 @@ class LLMNode(BaseNode[LLMNodeData]):
             query = None
 
         prompt_messages, stop = self._fetch_prompt_messages(
-            system_query=query,
-            inputs=inputs,
-            files=files,
+            user_query=query,
+            user_files=files,
             context=context,
             memory=memory,
             model_config=model_config,
@@ -203,7 +215,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         self,
         node_data_model: ModelConfig,
         model_instance: ModelInstance,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         stop: Optional[Sequence[str]] = None,
     ) -> Generator[NodeEvent, None, None]:
         db.session.close()
@@ -519,9 +531,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_prompt_messages(
         self,
         *,
-        system_query: str | None = None,
-        inputs: dict[str, str] | None = None,
-        files: Sequence["File"],
+        user_query: str | None = None,
+        user_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -529,60 +540,161 @@ class LLMNode(BaseNode[LLMNodeData]):
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
-    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        inputs = inputs or {}
-
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs=inputs,
-            query=system_query or "",
-            files=files,
-            context=context,
-            memory_config=memory_config,
-            memory=memory,
-            model_config=model_config,
-        )
-        stop = model_config.stop
-
+    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
+        prompt_messages = []
+
+        if isinstance(prompt_template, list):
+            # For chat model
+            prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context))
+
+            # Get memory messages for chat mode
+            memory_messages = self._handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Extend prompt_messages with memory messages
+            prompt_messages.extend(memory_messages)
+
+            # Add current query to the prompt messages
+            if user_query:
+                prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)]))
+
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context))
+
+            # Get memory text for completion model
+            memory_text = self._handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Insert histories into the prompt
+            prompt_content = prompt_messages[0].content
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+
+            # Add current query to the prompt message
+            if user_query:
+                prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                prompt_messages[0].content = prompt_content
+        else:
+            errmsg = f"Prompt type {type(prompt_template)} is not supported"
+            logger.warning(errmsg)
+            raise NotSupportedPromptTypeError(errmsg)
+
+        if vision_enabled and user_files:
+            file_prompts = []
+            for file in user_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+        # Filter prompt messages
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
-            if prompt_message.is_empty():
-                continue
-
-            if not isinstance(prompt_message.content, str):
+            if isinstance(prompt_message.content, list):
                 prompt_message_content = []
-                for content_item in prompt_message.content or []:
+                for content_item in prompt_message.content:
                     # Skip image if vision is disabled
                     if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
                         continue
-
-                    if isinstance(content_item, ImagePromptMessageContent):
-                        # Override vision config if LLM node has vision config,
-                        # cuz vision detail is related to the configuration from FileUpload feature.
-                        content_item.detail = vision_detail
-                        prompt_message_content.append(content_item)
-                    elif isinstance(
-                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
-                    ):
-                        prompt_message_content.append(content_item)
-
-                if len(prompt_message_content) > 1:
-                    prompt_message.content = prompt_message_content
-                elif (
-                    len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
-                ):
+                    prompt_message_content.append(content_item)
+                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                     prompt_message.content = prompt_message_content[0].data
-
+                else:
+                    prompt_message.content = prompt_message_content
+            if prompt_message.is_empty():
+                continue
             filtered_prompt_messages.append(prompt_message)
 
-        if not filtered_prompt_messages:
+        if len(filtered_prompt_messages) == 0:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )
 
+        stop = model_config.stop
         return filtered_prompt_messages, stop
 
+    def _handle_memory_chat_mode(
+        self,
+        *,
+        memory: TokenBufferMemory | None,
+        memory_config: MemoryConfig | None,
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> Sequence[PromptMessage]:
+        memory_messages = []
+        # Get messages from memory for chat model
+        if memory and memory_config:
+            rest_tokens = self._calculate_rest_token([], model_config)
+            memory_messages = memory.get_history_prompt_messages(
+                max_token_limit=rest_tokens,
+                message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            )
+        return memory_messages
+
+    def _handle_memory_completion_mode(
+        self,
+        *,
+        memory: TokenBufferMemory | None,
+        memory_config: MemoryConfig | None,
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> str:
+        memory_text = ""
+        # Get history text from memory for completion model
+        if memory and memory_config:
+            rest_tokens = self._calculate_rest_token([], model_config)
+            if not memory_config.role_prefix:
+                raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+            memory_text = memory.get_history_prompt_text(
+                max_token_limit=rest_tokens,
+                message_limit=memory_config.window.size if memory_config.window.enabled else None,
+                human_prefix=memory_config.role_prefix.user,
+                ai_prefix=memory_config.role_prefix.assistant,
+            )
+        return memory_text
+
+    def _calculate_rest_token(
+        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    ) -> int:
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+            )
+
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if parameter_rule.name == "max_tokens" or (
+                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+                ):
+                    max_tokens = (
+                        model_config.parameters.get(parameter_rule.name)
+                        or model_config.parameters.get(str(parameter_rule.use_template))
+                        or 0
+                    )
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
@@ -715,3 +827,121 @@ class LLMNode(BaseNode[LLMNodeData]):
                 }
             },
         }
+
+    def _handle_list_messages(
+        self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str]
+    ) -> Sequence[PromptMessage]:
+        prompt_messages = []
+        for message in messages:
+            if message.edition_type == "jinja2":
+                result_text = _render_jinja2_message(
+                    template=message.jinja2_text or "",
+                    jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                )
+                prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+                prompt_messages.append(prompt_message)
+            else:
+                # Get segment group from basic message
+                segment_group = _render_basic_message(
+                    template=message.text,
+                    context=context,
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                )
+
+                # Process segments for images
+                image_contents = []
+                for segment in segment_group.value:
+                    if isinstance(segment, ArrayFileSegment):
+                        for file in segment.value:
+                            if file.type == FileType.IMAGE:
+                                image_content = file_manager.to_prompt_message_content(
+                                    file, image_detail_config=self.node_data.vision.configs.detail
+                                )
+                                image_contents.append(image_content)
+                    if isinstance(segment, FileSegment):
+                        file = segment.value
+                        if file.type == FileType.IMAGE:
+                            image_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=self.node_data.vision.configs.detail
+                            )
+                            image_contents.append(image_content)
+
+                # Create message with text from all segments
+                prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role)
+                prompt_messages.append(prompt_message)
+
+                if image_contents:
+                    # Create message with image contents
+                    prompt_message = UserPromptMessage(content=image_contents)
+                    prompt_messages.append(prompt_message)
+
+        return prompt_messages
+
+    def _handle_completion_template(
+        self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str]
+    ) -> Sequence[PromptMessage]:
+        prompt_messages = []
+        if template.edition_type == "jinja2":
+            result_text = _render_jinja2_message(
+                template=template.jinja2_text or "",
+                jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
+                variable_pool=self.graph_runtime_state.variable_pool,
+            )
+        else:
+            result_text = _render_basic_message(
+                template=template.text,
+                context=context,
+                variable_pool=self.graph_runtime_state.variable_pool,
+            ).text
+        prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+        prompt_messages.append(prompt_message)
+        return prompt_messages
+
+
+def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+    raise NotImplementedError(f"Role {role} is not supported")
+
+
+def _render_jinja2_message(
+    *,
+    template: str,
+    jinjia2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+):
+    if not template:
+        return ""
+
+    jinjia2_inputs = {}
+    for jinja2_variable in jinjia2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    code_execute_resp = CodeExecutor.execute_workflow_code_template(
+        language=CodeLanguage.JINJA2,
+        code=template,
+        inputs=jinjia2_inputs,
+    )
+    result_text = code_execute_resp["result"]
+    return result_text
+
+
+def _render_basic_message(
+    *,
+    template: str,
+    context: str | None,
+    variable_pool: VariablePool,
+) -> SegmentGroup:
+    if not template:
+        return SegmentGroup(value=[])
+
+    if context:
+        template = template.replace("{#context#}", context)
+
+    return variable_pool.convert_template(template)
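One detail worth calling out from the node changes above: the token budget handed to memory comes from _calculate_rest_token, which is plain arithmetic — rest = context_size - max_tokens - current_prompt_tokens, floored at zero, with 2000 used as a fallback when the model schema declares no context size. A worked example with assumed numbers:

    context_size = 16384           # ModelPropertyKey.CONTEXT_SIZE from the model schema (assumed value)
    max_tokens = 1024              # the node's configured "max_tokens" parameter (assumed value)
    current_prompt_tokens = 3200   # tokens already used by the rendered prompt (assumed value)

    rest_tokens = max(context_size - max_tokens - current_prompt_tokens, 0)
    assert rest_tokens == 12160    # budget passed to memory.get_history_prompt_messages / get_history_prompt_text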

View File

@@ -1,125 +1,401 @@
+from collections.abc import Sequence
+from typing import Optional
+
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.file import File, FileTransferMethod, FileType
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.entities import (
+    ContextConfig,
+    LLMNodeChatModelMessage,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+    VisionConfigOptions,
+)
 from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
+from models.provider import ProviderType
 from models.workflow import WorkflowType
 
 
-class TestLLMNode:
-    @pytest.fixture
-    def llm_node(self):
-        data = LLMNodeData(
-            title="Test LLM",
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
-            prompt_template=[],
-            memory=None,
-            context=ContextConfig(enabled=False),
-            vision=VisionConfig(
-                enabled=True,
-                configs=VisionConfigOptions(
-                    variable_selector=["sys", "files"],
-                    detail=ImagePromptMessageContent.DETAIL.HIGH,
-                ),
-            ),
-        )
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-        )
-        node = LLMNode(
-            id="1",
-            config={
-                "id": "1",
-                "data": data.model_dump(),
-            },
-            graph_init_params=GraphInitParams(
-                tenant_id="1",
-                app_id="1",
-                workflow_type=WorkflowType.WORKFLOW,
-                workflow_id="1",
-                graph_config={},
-                user_id="1",
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.SERVICE_API,
-                call_depth=0,
-            ),
-            graph=Graph(
-                root_node_id="1",
-                answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                    answer_dependencies={},
-                    answer_generate_route={},
-                ),
-                end_stream_param=EndStreamParam(
-                    end_dependencies={},
-                    end_stream_variable_selector_mapping={},
-                ),
-            ),
-            graph_runtime_state=GraphRuntimeState(
-                variable_pool=variable_pool,
-                start_at=0,
-            ),
-        )
-        return node
-
-    def test_fetch_files_with_file_segment(self, llm_node):
-        file = File(
-            id="1",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="test.jpg",
-            transfer_method=FileTransferMethod.LOCAL_FILE,
-            related_id="1",
-        )
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == [file]
-
-    def test_fetch_files_with_array_file_segment(self, llm_node):
-        files = [
-            File(
-                id="1",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test1.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="1",
-            ),
-            File(
-                id="2",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test2.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="2",
-            ),
-        ]
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == files
-
-    def test_fetch_files_with_none_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_array_any_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_non_existent_variable(self, llm_node):
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+class MockTokenBufferMemory:
+    def __init__(self, history_messages=None):
+        self.history_messages = history_messages or []
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        if message_limit is not None:
+            return self.history_messages[-message_limit * 2 :]
+        return self.history_messages
+
+
+@pytest.fixture
+def llm_node():
+    data = LLMNodeData(
+        title="Test LLM",
+        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+        prompt_template=[],
+        memory=None,
+        context=ContextConfig(enabled=False),
+        vision=VisionConfig(
+            enabled=True,
+            configs=VisionConfigOptions(
+                variable_selector=["sys", "files"],
+                detail=ImagePromptMessageContent.DETAIL.HIGH,
+            ),
+        ),
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = LLMNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
+            ),
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
+            ),
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+@pytest.fixture
+def model_config():
+    # Create actual provider and model type instances
+    model_provider_factory = ModelProviderFactory()
+    provider_instance = model_provider_factory.get_provider_instance("openai")
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+    # Create a ProviderModelBundle
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id="1",
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(enabled=False),
+            custom_configuration=CustomConfiguration(provider=None),
+            model_settings=[],
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance,
+    )
+
+    # Create and return a ModelConfigWithCredentialsEntity
+    return ModelConfigWithCredentialsEntity(
+        provider="openai",
+        model="gpt-3.5-turbo",
+        model_schema=AIModelEntity(
+            model="gpt-3.5-turbo",
+            label=I18nObject(en_US="GPT-3.5 Turbo"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={},
+        ),
+        mode="chat",
+        credentials={},
+        parameters={},
+        provider_model_bundle=provider_model_bundle,
+    )
+
+
+def test_fetch_files_with_file_segment(llm_node):
+    file = File(
+        id="1",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="test.jpg",
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        related_id="1",
+    )
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == [file]
+
+
+def test_fetch_files_with_array_file_segment(llm_node):
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="1",
+        ),
+        File(
+            id="2",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test2.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="2",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == files
+
+
+def test_fetch_files_with_none_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_array_any_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_non_existent_variable(llm_node):
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
+    prompt_template = []
+    llm_node.node_data.prompt_template = prompt_template
+
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
+            related_id="1",
+        )
+    ]
+
+    fake_query = faker.sentence()
+
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=None,
+        memory=None,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=None,
+        vision_enabled=False,
+        vision_detail=fake_vision_detail,
+    )
+
+    assert prompt_messages == [UserPromptMessage(content=fake_query)]
+
+
+def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+    # Setup dify config
+    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
+
+    # Generate fake values for prompt template
+    fake_user_prompt = faker.sentence()
+    fake_assistant_prompt = faker.sentence()
+    fake_query = faker.sentence()
+    random_context = faker.sentence()
+
+    # Generate fake values for vision
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    fake_prompt_image_url = faker.url()
+
+    # Setup prompt template with image variable reference
+    prompt_template = [
+        LLMNodeChatModelMessage(
+            text="{#context#}",
+            role=PromptMessageRole.SYSTEM,
+            edition_type="basic",
+        ),
+        LLMNodeChatModelMessage(
+            text="{{#input.image#}}",
+            role=PromptMessageRole.USER,
+            edition_type="basic",
+        ),
+        LLMNodeChatModelMessage(
+            text=fake_assistant_prompt,
+            role=PromptMessageRole.ASSISTANT,
+            edition_type="basic",
+        ),
+        LLMNodeChatModelMessage(
+            text="{{#input.images#}}",
+            role=PromptMessageRole.USER,
+            edition_type="basic",
+        ),
+    ]
+    llm_node.node_data.prompt_template = prompt_template
+
+    # Setup vision files
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
+            related_id="1",
+        )
+    ]
+
+    # Setup prompt image in variable pool
+    prompt_image = File(
+        id="2",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="prompt_image.jpg",
+        transfer_method=FileTransferMethod.REMOTE_URL,
+        remote_url=fake_prompt_image_url,
+        related_id="2",
+    )
+    prompt_images = [
+        File(
+            id="3",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="prompt_image.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_prompt_image_url,
+            related_id="3",
+        ),
+        File(
+            id="4",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="prompt_image.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_prompt_image_url,
+            related_id="4",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image)
+    llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images)
+
+    # Setup memory configuration with random window size
+    window_size = faker.random_int(min=1, max=3)
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=window_size),
+        query_prompt_template=None,
+    )
+
+    # Setup mock memory with history messages
+    mock_history = [
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+    ]
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Call the method under test
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=random_context,
+        memory=memory,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=memory_config,
+        vision_enabled=True,
+        vision_detail=fake_vision_detail,
+    )
+
+    # Build expected messages
+    expected_messages = [
+        # Base template messages
+        SystemPromptMessage(content=random_context),
+        # Image from variable pool in prompt template
+        UserPromptMessage(
+            content=[
+                ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
+            ]
+        ),
+        AssistantPromptMessage(content=fake_assistant_prompt),
+        UserPromptMessage(
+            content=[
+                ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
+                ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail),
+            ]
+        ),
+    ]
+
+    # Add memory messages based on window size
+    expected_messages.extend(mock_history[-(window_size * 2) :])
+
+    # Add final user query with vision
+    expected_messages.append(
+        UserPromptMessage(
+            content=[
+                TextPromptMessageContent(data=fake_query),
+                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+            ]
+        )
+    )
+
+    # Verify the result
+    assert prompt_messages == expected_messages
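The mock memory above also documents how the window is applied: get_history_prompt_messages with message_limit set to the window size keeps only the last window_size * 2 messages, i.e. the most recent user/assistant pairs, which is exactly what the expected-message list slices out with mock_history[-(window_size * 2):]. A small self-contained illustration with assumed message contents:

    history = [
        UserPromptMessage(content="q1"),
        AssistantPromptMessage(content="a1"),
        UserPromptMessage(content="q2"),
        AssistantPromptMessage(content="a2"),
        UserPromptMessage(content="q3"),
        AssistantPromptMessage(content="a3"),
    ]
    memory = MockTokenBufferMemory(history_messages=history)

    # A window size of 2 keeps the last two user/assistant pairs (4 messages).
    assert memory.get_history_prompt_messages(message_limit=2) == history[-4:]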