feat(tests): refactor LLMNode tests for clarity

Refactor test scenarios in LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to enhance readability and consistency. This change simplifies the test case management by encapsulating scenario data and reduces redundancy in specifying test configurations. Improves test clarity and maintainability by using a structured approach.
This commit is contained in:
-LAN- 2024-11-14 19:54:21 +08:00
parent 9f0f82cb1c
commit 97fab7649b
2 changed files with 52 additions and 30 deletions

View File

@ -39,6 +39,7 @@ from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
class MockTokenBufferMemory: class MockTokenBufferMemory:
@ -224,7 +225,6 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
filename="test1.jpg", filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url, remote_url=fake_remote_url,
related_id="1",
) )
] ]
@ -280,13 +280,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Test scenarios covering different file input combinations # Test scenarios covering different file input combinations
test_scenarios = [ test_scenarios = [
{ LLMNodeTestScenario(
"description": "No files", description="No files",
"user_query": fake_query, user_query=fake_query,
"user_files": [], user_files=[],
"features": [], features=[],
"window_size": fake_window_size, vision_enabled=False,
"prompt_template": [ vision_detail=None,
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage( LLMNodeChatModelMessage(
text=fake_context, text=fake_context,
role=PromptMessageRole.SYSTEM, role=PromptMessageRole.SYSTEM,
@ -303,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
edition_type="basic", edition_type="basic",
), ),
], ],
"expected_messages": [ expected_messages=[
SystemPromptMessage(content=fake_context), SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context), UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt), AssistantPromptMessage(content=fake_assistant_prompt),
@ -312,11 +314,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+ [ + [
UserPromptMessage(content=fake_query), UserPromptMessage(content=fake_query),
], ],
}, ),
{ LLMNodeTestScenario(
"description": "User files", description="User files",
"user_query": fake_query, user_query=fake_query,
"user_files": [ user_files=[
File( File(
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
@ -325,11 +327,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
remote_url=fake_remote_url, remote_url=fake_remote_url,
) )
], ],
"vision_enabled": True, vision_enabled=True,
"vision_detail": fake_vision_detail, vision_detail=fake_vision_detail,
"features": [ModelFeature.VISION], features=[ModelFeature.VISION],
"window_size": fake_window_size, window_size=fake_window_size,
"prompt_template": [ prompt_template=[
LLMNodeChatModelMessage( LLMNodeChatModelMessage(
text=fake_context, text=fake_context,
role=PromptMessageRole.SYSTEM, role=PromptMessageRole.SYSTEM,
@ -346,7 +348,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
edition_type="basic", edition_type="basic",
), ),
], ],
"expected_messages": [ expected_messages=[
SystemPromptMessage(content=fake_context), SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context), UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt), AssistantPromptMessage(content=fake_assistant_prompt),
@ -360,27 +362,27 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
] ]
), ),
], ],
}, ),
] ]
for scenario in test_scenarios: for scenario in test_scenarios:
model_config.model_schema.features = scenario["features"] model_config.model_schema.features = scenario.features
# Call the method under test # Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages( prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query, user_query=scenario.user_query,
user_files=scenario["user_files"], user_files=scenario.user_files,
context=fake_context, context=fake_context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
prompt_template=scenario["prompt_template"], prompt_template=scenario.prompt_template,
memory_config=memory_config, memory_config=memory_config,
vision_enabled=True, vision_enabled=scenario.vision_enabled,
vision_detail=fake_vision_detail, vision_detail=scenario.vision_detail,
) )
# Verify the result # Verify the result
assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert ( assert (
prompt_messages == scenario["expected_messages"] prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario['description']}" ), f"Message content mismatch in scenario: {scenario.description}"

View File

@ -0,0 +1,20 @@
from pydantic import BaseModel, Field
from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: list[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: list[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing")