This commit is contained in:
-LAN- 2024-11-15 18:00:11 +08:00 committed by GitHub
commit e7fb51d5a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 959 additions and 261 deletions

View File

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file # read from dotenv format config file
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes # ignore extra attributes
extra="ignore", extra="ignore",
) )

View File

@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
@ -38,27 +38,23 @@ class ModelConfigConverter:
) )
if model_credentials is None: if model_credentials is None:
if not skip_check: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}
if not skip_check: # check model
# check model provider_model = provider_model_bundle.configuration.get_provider_model(
provider_model = provider_model_bundle.configuration.get_provider_model( model=model_config.model, model_type=ModelType.LLM
model=model_config.model, model_type=ModelType.LLM )
)
if provider_model is None: if provider_model is None:
model_name = model_config.model model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE: if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION: elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters
@ -76,7 +72,7 @@ class ModelConfigConverter:
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity( return ModelConfigWithCredentialsEntity(

View File

@ -217,6 +217,7 @@ class WorkflowCycleManage:
).total_seconds() ).total_seconds()
db.session.commit() db.session.commit()
db.session.add(workflow_run)
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
db.session.close() db.session.close()

View File

@ -74,6 +74,8 @@ def to_prompt_message_content(
data = _to_url(f) data = _to_url(f)
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _: case _:
raise ValueError("file type f.type is not supported") raise ValueError("file type f.type is not supported")

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -27,7 +28,7 @@ class TokenBufferMemory:
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
:param max_token_limit: max token limit :param max_token_limit: max token limit

View File

@ -100,10 +100,10 @@ class ModelInstance:
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
): ):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:

View File

@ -1,4 +1,5 @@
from abc import ABC from abc import ABC
from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video" VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -107,7 +109,7 @@ class PromptMessage(ABC, BaseModel):
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

View File

@ -87,6 +87,9 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call" STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum): class DefaultParameterName(str, Enum):

View File

@ -2,7 +2,7 @@ import logging
import re import re
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -212,7 +212,7 @@ if you are not sure about the structure.
) )
model_parameters.pop("response_format") model_parameters.pop("response_format")
stop = stop or [] stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"]) stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block) block_prompts = block_prompts.replace("{{block}}", code_block)
@ -408,7 +408,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -479,7 +479,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
@ -601,7 +601,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -647,7 +647,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -694,7 +694,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -742,7 +742,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

View File

@ -8,6 +8,7 @@ features:
- agent-thought - agent-thought
- stream-tool-call - stream-tool-call
- vision - vision
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import cast from typing import cast
from core.model_runtime.entities import ( from core.model_runtime.entities import (
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil: class PromptMessageUtil:
@staticmethod @staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
""" """
Prompt messages to prompt for saving. Prompt messages to prompt for saving.
:param model_mode: model mode :param model_mode: model mode

View File

@ -118,11 +118,11 @@ class FileSegment(Segment):
@property @property
def log(self) -> str: def log(self) -> str:
return str(self.value) return ""
@property @property
def text(self) -> str: def text(self) -> str:
return str(self.value) return ""
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
for item in self.value: for item in self.value:
items.append(item.markdown) items.append(item.markdown)
return "\n".join(items) return "\n".join(items)
@property
def log(self) -> str:
return ""
@property
def text(self) -> str:
return ""

View File

@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
class PromptConfig(BaseModel): class PromptConfig(BaseModel):
jinja2_variables: Optional[list[VariableSelector]] = None jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
if v is None:
return []
return v
class LLMNodeChatModelMessage(ChatModelMessage): class LLMNodeChatModelMessage(ChatModelMessage):
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData): class LLMNodeData(BaseNodeData):
model: ModelConfig model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
if v is None:
return PromptConfig()
return v

View File

@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
class NoPromptFoundError(LLMNodeError): class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration.""" """Raised when no prompt is found in the LLM configuration."""
class NotSupportedPromptTypeError(LLMNodeError):
"""Raised when the prompt type is not supported."""
class MemoryRolePrefixRequiredError(LLMNodeError):
"""Raised when memory role prefix is required for completion model."""

View File

@ -1,4 +1,5 @@
import json import json
import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import ( from core.variables import (
@ -32,8 +38,9 @@ from core.variables import (
ObjectSegment, ObjectSegment,
StringSegment, StringSegment,
) )
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -62,14 +69,18 @@ from .exc import (
InvalidVariableTypeError, InvalidVariableTypeError,
LLMModeRequiredError, LLMModeRequiredError,
LLMNodeError, LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError, ModelNotExistError,
NoPromptFoundError, NoPromptFoundError,
NotSupportedPromptTypeError,
VariableNotFoundError, VariableNotFoundError,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]): class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData _node_data_cls = LLMNodeData
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch prompt messages # fetch prompt messages
if self.node_data.memory: if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query = self.node_data.memory.query_prompt_template
if not query:
raise VariableNotFoundError("Query not found")
query = query.text
else: else:
query = None query = None
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
system_query=query, user_query=query,
inputs=inputs, user_files=files,
files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory, memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
) )
process_data = { process_data = {
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
) )
return return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run: {e}")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self, self,
node_data_model: ModelConfig, node_data_model: ModelConfig,
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]: ) -> Generator[NodeEvent, None, None]:
db.session.close() db.session.close()
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages( def _fetch_prompt_messages(
self, self,
*, *,
system_query: str | None = None, user_query: str | None = None,
inputs: dict[str, str] | None = None, user_files: Sequence["File"],
files: Sequence["File"],
context: str | None = None, context: str | None = None,
memory: TokenBufferMemory | None = None, memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None, memory_config: MemoryConfig | None = None,
vision_enabled: bool = False, vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]: variable_pool: VariablePool,
inputs = inputs or {} jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) if isinstance(prompt_template, list):
prompt_messages = prompt_transform.get_prompt( # For chat model
prompt_template=prompt_template, prompt_messages.extend(
inputs=inputs, _handle_list_messages(
query=system_query or "", messages=prompt_template,
files=files, context=context,
context=context, jinja2_variables=jinja2_variables,
memory_config=memory_config, variable_pool=variable_pool,
memory=memory, vision_detail_config=vision_detail,
model_config=model_config, )
) )
stop = model_config.stop
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
errmsg = f"Prompt type {type(prompt_template)} is not supported"
logger.warning(errmsg)
raise NotSupportedPromptTypeError(errmsg)
if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
filtered_prompt_messages = [] filtered_prompt_messages = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.is_empty(): if isinstance(prompt_message.content, list):
continue
if not isinstance(prompt_message.content, str):
prompt_message_content = [] prompt_message_content = []
for content_item in prompt_message.content or []: for content_item in prompt_message.content:
# Skip image if vision is disabled # Skip content if features are not defined
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue continue
if isinstance(content_item, ImagePromptMessageContent): # Skip content if corresponding feature is not supported
# Override vision config if LLM node has vision config, if (
# cuz vision detail is related to the configuration from FileUpload feature. (
content_item.detail = vision_detail content_item.type == PromptMessageContentType.IMAGE
prompt_message_content.append(content_item) and ModelFeature.VISION not in model_config.model_schema.features
elif isinstance( )
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
)
): ):
prompt_message_content.append(content_item) continue
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1: if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
prompt_message.content = prompt_message_content[0].data prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message) filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages: if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError( raise NoPromptFoundError(
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
stop = model_config.stop
return filtered_prompt_messages, stop return filtered_prompt_messages, stop
@classmethod @classmethod
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
} }
}, },
} }
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
) )
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
system_query=query, user_query=query,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
files=files, user_files=files,
vision_enabled=node_data.vision.enabled, vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
) )
# handle invoke result # handle invoke result

61
api/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "aiohappyeyeballs" name = "aiohappyeyeballs"
@ -932,10 +932,6 @@ files = [
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"},
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"},
{file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
{file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
{file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
@ -948,14 +944,8 @@ files = [
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"},
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"},
{file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
{file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"},
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"},
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
{file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
@ -966,24 +956,8 @@ files = [
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"},
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"},
{file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
{file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"},
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"},
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"},
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"},
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"},
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"},
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"},
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"},
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"},
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"},
{file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"},
{file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"},
{file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
@ -993,10 +967,6 @@ files = [
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"},
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"},
{file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
{file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
{file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
@ -1008,10 +978,6 @@ files = [
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"},
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"},
{file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
{file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
{file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
@ -1024,10 +990,6 @@ files = [
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"},
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"},
{file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
{file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
{file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
@ -1040,10 +1002,6 @@ files = [
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"},
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"},
{file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
{file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
{file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
@ -2453,6 +2411,21 @@ files = [
[package.extras] [package.extras]
test = ["pytest (>=6)"] test = ["pytest (>=6)"]
[[package]]
name = "faker"
version = "32.1.0"
description = "Faker is a Python package that generates fake data for you."
optional = false
python-versions = ">=3.8"
files = [
{file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
{file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
]
[package.dependencies]
python-dateutil = ">=2.4"
typing-extensions = "*"
[[package]] [[package]]
name = "fal-client" name = "fal-client"
version = "0.5.6" version = "0.5.6"
@ -11078,4 +11051,4 @@ cffi = ["cffi (>=1.11)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "2ba4b464eebc26598f290fa94713acc44c588f902176e6efa80622911d40f0ac" content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"

View File

@ -267,6 +267,7 @@ weaviate-client = "~3.21.0"
optional = true optional = true
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
coverage = "~7.2.4" coverage = "~7.2.4"
faker = "~32.1.0"
pytest = "~8.3.2" pytest = "~8.3.2"
pytest-benchmark = "~4.0.0" pytest-benchmark = "~4.0.0"
pytest-env = "~1.1.3" pytest-env = "~1.1.3"

View File

@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

View File

@ -4,29 +4,21 @@ import pytest
from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
def test_validate_credentials(): def test_validate_credentials():
model = AzureAIStudioRerankModel() model = AzureRerankModel()
with pytest.raises(CredentialsValidateFailedError): with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials( model.validate_credentials(
model="azure-ai-studio-rerank-v1", model="azure-ai-studio-rerank-v1",
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
) )
def test_invoke_model(): def test_invoke_model():
model = AzureAIStudioRerankModel() model = AzureRerankModel()
result = model.invoke( result = model.invoke(
model="azure-ai-studio-rerank-v1", model="azure-ai-studio-rerank-v1",

View File

@ -1,125 +1,484 @@
from collections.abc import Sequence
from typing import Optional
import pytest import pytest
from core.app.entities.app_invoke_entities import InvokeFrom from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
class TestLLMNode: class MockTokenBufferMemory:
@pytest.fixture def __init__(self, history_messages=None):
def llm_node(self): self.history_messages = history_messages or []
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
def test_fetch_files_with_file_segment(self, llm_node): def get_history_prompt_messages(
file = File( self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]
return self.history_messages
@pytest.fixture
def llm_node():
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory()
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance.get_provider_schema(),
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
provider_instance=provider_instance,
model_type_instance=model_type_instance,
)
# Create and return a ModelConfigWithCredentialsEntity
return ModelConfigWithCredentialsEntity(
provider="openai",
model="gpt-3.5-turbo",
model_schema=AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(en_US="GPT-3.5 Turbo"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={},
),
mode="chat",
credentials={},
parameters={},
provider_model_bundle=provider_model_bundle,
)
def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1", id="1",
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test.jpg", filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE, transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1", related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
prompt_template = []
llm_node.node_data.prompt_template = prompt_template
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
) )
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) ]
result = llm_node._fetch_files(selector=["sys", "files"]) fake_query = faker.sentence()
assert result == [file]
def test_fetch_files_with_array_file_segment(self, llm_node): prompt_messages, _ = llm_node._fetch_prompt_messages(
files = [ user_query=fake_query,
File( user_files=files,
id="1", context=None,
tenant_id="test", memory=None,
type=FileType.IMAGE, model_config=model_config,
filename="test1.jpg", prompt_template=prompt_template,
transfer_method=FileTransferMethod.LOCAL_FILE, memory_config=None,
related_id="1", vision_enabled=False,
), vision_detail=fake_vision_detail,
File( variable_pool=llm_node.graph_runtime_state.variable_pool,
id="2", jinja2_variables=[],
tenant_id="test", )
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"]) assert prompt_messages == [UserPromptMessage(content=fake_query)]
assert result == files
def test_fetch_files_with_none_segment(self, llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"]) def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
assert result == [] # Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
def test_fetch_files_with_array_any_segment(self, llm_node): # Generate fake values for prompt template
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) fake_assistant_prompt = faker.sentence()
fake_query = faker.sentence()
fake_context = faker.sentence()
fake_window_size = faker.random_int(min=1, max=3)
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
result = llm_node._fetch_files(selector=["sys", "files"]) # Setup mock memory with history messages
assert result == [] mock_history = [
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
]
def test_fetch_files_with_non_existent_variable(self, llm_node): # Setup memory configuration
result = llm_node._fetch_files(selector=["sys", "files"]) memory_config = MemoryConfig(
assert result == [] role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
query_prompt_template=None,
)
memory = MockTokenBufferMemory(history_messages=mock_history)
# Test scenarios covering different file input combinations
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(content=fake_query),
],
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
]
+ mock_history[fake_window_size * -2 :]
+ [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File without vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File with video file and vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.VIDEO,
filename="test1.mp4",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension="mp4",
)
},
),
]
for scenario in test_scenarios:
model_config.model_schema.features = scenario.features
for k, v in scenario.file_variables.items():
selector = k.split(".")
llm_node.graph_runtime_state.variable_pool.add(selector, v)
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
context=fake_context,
memory=memory,
model_config=model_config,
prompt_template=scenario.prompt_template,
memory_config=memory_config,
vision_enabled=scenario.vision_enabled,
vision_detail=scenario.vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
# Verify the result
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"

View File

@ -0,0 +1,25 @@
from collections.abc import Mapping, Sequence
from pydantic import BaseModel, Field
from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
file_variables: Mapping[str, File | Sequence[File]] = Field(
default_factory=dict, description="List of file variables"
)
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")

View File

@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
} }
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => { export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
if (fileMimetype)
return mime.getExtension(fileMimetype) || ''
if (isRemote)
return ''
if (fileName) { if (fileName) {
const fileNamePair = fileName.split('.') const fileNamePair = fileName.split('.')
const fileNamePairLength = fileNamePair.length const fileNamePairLength = fileNamePair.length
@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
return fileNamePair[fileNamePairLength - 1] return fileNamePair[fileNamePairLength - 1]
} }
if (fileMimetype)
return mime.getExtension(fileMimetype) || ''
if (isRemote)
return ''
return '' return ''
} }

View File

@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
onEditionTypeChange={onEditionTypeChange} onEditionTypeChange={onEditionTypeChange}
varList={varList} varList={varList}
handleAddVariable={handleAddVariable} handleAddVariable={handleAddVariable}
isSupportFileVar
/> />
) )
} }

View File

@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
handleStop, handleStop,
varInputs, varInputs,
runResult, runResult,
filterJinjia2InputVar,
} = useConfig(id, data) } = useConfig(id, data)
const model = inputs.model const model = inputs.model
@ -194,7 +195,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
list={inputs.prompt_config?.jinja2_variables || []} list={inputs.prompt_config?.jinja2_variables || []}
onChange={handleVarListChange} onChange={handleVarListChange}
onVarNameChange={handleVarNameChange} onVarNameChange={handleVarNameChange}
filterVar={filterVar} filterVar={filterJinjia2InputVar}
/> />
</Field> </Field>
)} )}
@ -233,6 +234,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
hasSetBlockStatus={hasSetBlockStatus} hasSetBlockStatus={hasSetBlockStatus}
nodesOutputVars={availableVars} nodesOutputVars={availableVars}
availableNodes={availableNodesWithParent} availableNodes={availableNodesWithParent}
isSupportFileVar
/> />
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (

View File

@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
}, [inputs, setInputs]) }, [inputs, setInputs])
const filterInputVar = useCallback((varPayload: Var) => { const filterInputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
}, [])
const filterJinjia2InputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
}, []) }, [])
const filterMemoryPromptVar = useCallback((varPayload: Var) => { const filterMemoryPromptVar = useCallback((varPayload: Var) => {
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
}, []) }, [])
const { const {
@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
handleRun, handleRun,
handleStop, handleStop,
runResult, runResult,
filterJinjia2InputVar,
} }
} }