Mirror of https://github.com/langgenius/dify.git, synced 2024-11-16 11:42:29 +08:00
feat(tests): refactor LLMNode tests for clarity
Refactor the test scenarios in the LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to improve readability and consistency. The change simplifies test case management by encapsulating scenario data in a single structure and reduces redundancy in specifying test configurations, which improves test clarity and maintainability.
commit 97fab7649b (parent 9f0f82cb1c)
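For illustration only (this snippet is not part of the commit): a minimal, self-contained sketch of the pattern described above, using a hypothetical `ScenarioSketch` model with a reduced field set to show why a pydantic model is preferable to a raw dict for scenario data.

    # Hypothetical sketch, not dify code: the commit applies the same idea with
    # the full LLMNodeTestScenario model shown in the diff below.
    from pydantic import BaseModel, Field, ValidationError


    class ScenarioSketch(BaseModel):
        """Reduced stand-in for LLMNodeTestScenario."""

        description: str = Field(..., description="Description of the test scenario")
        user_query: str = Field(..., description="User query input")
        vision_enabled: bool = Field(default=False, description="Whether vision is enabled")


    # Before: a plain dict. A misspelled or missing key only fails when (and if)
    # the test body happens to index it.
    old_style = {"description": "No files", "user_query": "hi", "vison_enabled": False}

    # After: construction validates the data up front, defaults are applied by the
    # model, and fields are read as attributes (scenario.vision_enabled) instead of
    # string-keyed lookups (scenario["vision_enabled"]).
    scenario = ScenarioSketch(description="No files", user_query="hi")
    assert scenario.vision_enabled is False

    try:
        ScenarioSketch(description="missing the user query")  # user_query is required
    except ValidationError as exc:
        print(f"rejected invalid scenario: {len(exc.errors())} error(s)")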
Changes to the LLM node unit tests:

@@ -39,6 +39,7 @@ from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
 from models.provider import ProviderType
 from models.workflow import WorkflowType
+from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
 
 
 class MockTokenBufferMemory:
@@ -224,7 +225,6 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
             filename="test1.jpg",
             transfer_method=FileTransferMethod.REMOTE_URL,
             remote_url=fake_remote_url,
-            related_id="1",
         )
     ]
 
@@ -280,13 +280,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
 
     # Test scenarios covering different file input combinations
     test_scenarios = [
-        {
-            "description": "No files",
-            "user_query": fake_query,
-            "user_files": [],
-            "features": [],
-            "window_size": fake_window_size,
-            "prompt_template": [
+        LLMNodeTestScenario(
+            description="No files",
+            user_query=fake_query,
+            user_files=[],
+            features=[],
+            vision_enabled=False,
+            vision_detail=None,
+            window_size=fake_window_size,
+            prompt_template=[
                 LLMNodeChatModelMessage(
                     text=fake_context,
                     role=PromptMessageRole.SYSTEM,
@@ -303,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     edition_type="basic",
                 ),
             ],
-            "expected_messages": [
+            expected_messages=[
                 SystemPromptMessage(content=fake_context),
                 UserPromptMessage(content=fake_context),
                 AssistantPromptMessage(content=fake_assistant_prompt),
@@ -312,11 +314,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
             + [
                 UserPromptMessage(content=fake_query),
             ],
-        },
-        {
-            "description": "User files",
-            "user_query": fake_query,
-            "user_files": [
+        ),
+        LLMNodeTestScenario(
+            description="User files",
+            user_query=fake_query,
+            user_files=[
                 File(
                     tenant_id="test",
                     type=FileType.IMAGE,
@@ -325,11 +327,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     remote_url=fake_remote_url,
                 )
             ],
-            "vision_enabled": True,
-            "vision_detail": fake_vision_detail,
-            "features": [ModelFeature.VISION],
-            "window_size": fake_window_size,
-            "prompt_template": [
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
                 LLMNodeChatModelMessage(
                     text=fake_context,
                     role=PromptMessageRole.SYSTEM,
@@ -346,7 +348,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     edition_type="basic",
                 ),
             ],
-            "expected_messages": [
+            expected_messages=[
                 SystemPromptMessage(content=fake_context),
                 UserPromptMessage(content=fake_context),
                 AssistantPromptMessage(content=fake_assistant_prompt),
@@ -360,27 +362,27 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     ]
                 ),
             ],
-        },
+        ),
     ]
 
     for scenario in test_scenarios:
-        model_config.model_schema.features = scenario["features"]
+        model_config.model_schema.features = scenario.features
 
         # Call the method under test
         prompt_messages, _ = llm_node._fetch_prompt_messages(
-            user_query=fake_query,
-            user_files=scenario["user_files"],
+            user_query=scenario.user_query,
+            user_files=scenario.user_files,
             context=fake_context,
             memory=memory,
             model_config=model_config,
-            prompt_template=scenario["prompt_template"],
+            prompt_template=scenario.prompt_template,
             memory_config=memory_config,
-            vision_enabled=True,
-            vision_detail=fake_vision_detail,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
         )
 
         # Verify the result
-        assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}"
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
         assert (
-            prompt_messages == scenario["expected_messages"]
-        ), f"Message content mismatch in scenario: {scenario['description']}"
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"
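This commit keeps one test function that loops over `test_scenarios`. As a hypothetical alternative (not something the commit does), the same scenarios could be fed through `pytest.mark.parametrize`, which would report each scenario as its own test case and use the `description` field as the test ID; a rough sketch under that assumption:

    # Hypothetical alternative to the for-loop above, not part of this commit.
    import pytest

    # SCENARIOS stands in for the test_scenarios list built in the test above,
    # e.g. [LLMNodeTestScenario(...), LLMNodeTestScenario(...)].
    SCENARIOS = []


    @pytest.mark.parametrize("scenario", SCENARIOS, ids=lambda s: s.description)
    def test_fetch_prompt_messages_scenario(scenario):
        # The body would mirror the loop above: set the model features from the
        # scenario, call _fetch_prompt_messages, then compare the result against
        # scenario.expected_messages.
        assert scenario.description

One trade-off: parametrize surfaces per-scenario pass/fail in the test report, while the committed loop keeps all scenarios inside a single test and relies on the f-string assertion messages to identify the failing one.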
New file tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py:

@@ -0,0 +1,20 @@
+from pydantic import BaseModel, Field
+
+from core.file import File
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
+
+
+class LLMNodeTestScenario(BaseModel):
+    """Test scenario for LLM node testing."""
+
+    description: str = Field(..., description="Description of the test scenario")
+    user_query: str = Field(..., description="User query input")
+    user_files: list[File] = Field(default_factory=list, description="List of user files")
+    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
+    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
+    features: list[ModelFeature] = Field(default_factory=list, description="List of model features")
+    window_size: int = Field(..., description="Window size for memory")
+    prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing")
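A hedged usage sketch of the new model (it assumes the snippet runs inside dify's api test suite where these modules are importable; the import path for the prompt-message classes is inferred from the `PromptMessage` import in the file above): a minimal scenario built from only the required fields, read back the way the test loop reads it.

    # Sketch only: field names and the required/optional split are taken from the
    # LLMNodeTestScenario definition above; import paths are assumptions.
    from core.model_runtime.entities.message_entities import (
        PromptMessageRole,
        SystemPromptMessage,
        UserPromptMessage,
    )
    from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
    from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

    scenario = LLMNodeTestScenario(
        description="No files",
        user_query="what is dify?",
        window_size=5,
        prompt_template=[
            LLMNodeChatModelMessage(
                text="you are a helpful assistant",
                role=PromptMessageRole.SYSTEM,
                edition_type="basic",
            ),
        ],
        expected_messages=[
            SystemPromptMessage(content="you are a helpful assistant"),
            UserPromptMessage(content="what is dify?"),
        ],
    )

    # Fields not passed fall back to the defaults declared on the model.
    assert scenario.user_files == []
    assert scenario.features == []
    assert scenario.vision_enabled is False
    assert scenario.vision_detail is None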