Compare commits

...

32 Commits

Author SHA1 Message Date
-LAN-
abacc3768f Updates poetry.lock content hash for consistency
Changes the content hash in poetry.lock to ensure the lock
file's integrity aligns with the updated project dependencies.
No package versions changed in this update.
2024-11-15 11:52:47 +08:00
-LAN-
e31358219c feat(llm-panel): refine variable filtering logic
Introduce `filterJinjia2InputVar` to enhance variable filtering, specifically excluding `arrayFile` types from Jinja2 input variables. This adjustment improves the management of variable types, aligning with expected input capacities and ensuring more reliable configurations. Additionally, support for file variables is enabled in relevant components, broadening functionality and user options.
2024-11-15 11:51:56 +08:00
-LAN-
4e360ec19a refactor(core): decouple LLMNode prompt handling
Moved prompt handling functions out of the `LLMNode` class to improve modularity and separation of concerns. This refactor allows better reuse and testing of prompt-related functions. Adjusted existing logic to fetch queries and handle context and memory configurations more effectively. Updated tests to align with the new structure and ensure continued functionality.
2024-11-15 11:51:56 +08:00
-LAN-
f68d6bd5e2 refactor(node.py): streamline template rendering
Removed the `_render_basic_message` function and integrated its logic directly into the `LLMNode` class. This reduces redundancy and simplifies the handling of message templates by utilizing `convert_template` more directly. This change enhances code readability and maintainability.
2024-11-15 11:51:56 +08:00
-LAN-
b860a893c8 feat(config-prompt): add support for file variables
Extended the `ConfigPromptItem` component to support file variables by including the `isSupportFileVar` prop. Updated `useConfig` hooks to accept `arrayFile` variable types for both input and memory prompt filtering. This enhancement allows handling of file data types seamlessly, improving flexibility in configuring prompts.
2024-11-15 11:51:56 +08:00
-LAN-
0354c7813e fix(file-manager): enforce file extension presence
Added a check to ensure that files have an extension before processing to avoid potential errors. Updated unit tests to reflect this requirement by including extensions in test data. This prevents exceptions from being raised due to missing file extension information.
2024-11-15 11:51:56 +08:00
-LAN-
94794d892e feat: add support for document, video, and audio content
Expanded the system to handle document types across different modules and introduced video and audio content handling in model features. Adjusted the prompt message logic to conditionally process content based on available features, enhancing flexibility in media processing. Added comprehensive error handling in `LLMNode` for better runtime resilience. Updated YAML configuration and unit tests to reflect these changes.
2024-11-15 11:51:56 +08:00
-LAN-
fb94d0b7cf fix: ensure workflow run persistence before refresh
Adds the workflow run object to the database session to guarantee it is persisted prior to refreshing its state. This change resolves potential issues with data consistency and integrity when the workflow run is accessed after operations. References issue #123 for more context.
2024-11-15 11:51:56 +08:00
-LAN-
02c39b2631 fix(file-uploader): resolve file extension logic order
Rearranged the logic in `getFileExtension` to first check for a valid `fileName` before considering `fileMimetype` or `isRemote`. This change ensures that the function prioritizes extracting extensions from file names directly, improving accuracy and handling edge cases more effectively. This update may prevent incorrect file extensions when mimetype is prioritized incorrectly.

Resolves #123.
2024-11-15 11:51:56 +08:00
-LAN-
87f78ff582 feat: enhance image handling in prompt processing
Updated image processing logic to check for model support of vision features, preventing errors when handling images with models that do not support them. Added a test scenario to validate behavior when vision features are absent. This ensures robust image handling and avoids unexpected behavior during image-related prompts.
2024-11-15 11:51:56 +08:00
-LAN-
6872b32c7d fix(node): handle empty text segments gracefully
Ensure that messages are only created from non-empty text segments, preventing potential issues with empty content.

test: add scenario for file variable handling

Introduce a test case for scenarios involving prompt templates with file variables, particularly images, to improve reliability and test coverage. Updated `LLMNodeTestScenario` to use `Sequence` and `Mapping` for more flexible configurations.

Closes #123, relates to #456.
2024-11-15 11:51:56 +08:00
-LAN-
97fab7649b feat(tests): refactor LLMNode tests for clarity
Refactor test scenarios in LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to enhance readability and consistency. This change simplifies the test case management by encapsulating scenario data and reduces redundancy in specifying test configurations. Improves test clarity and maintainability by using a structured approach.
2024-11-15 11:51:56 +08:00
-LAN-
9f0f82cb1c refactor(tests): streamline LLM node prompt message tests
Refactored LLM node tests to enhance clarity and maintainability by creating test scenarios for different file input combinations. This restructuring replaces repetitive code with a more concise approach, improving test coverage and readability.

No functional code changes were made.

References: #123, #456
2024-11-15 11:51:56 +08:00
-LAN-
ef08abafdf Simplify test setup in LLM node tests
Replaced redundant variables in test setup to streamline and align usage of fake data, enhancing readability and maintainability. Adjusted image URL variables to utilize consistent references, ensuring uniformity across test configurations. Also, corrected context variable naming for clarity. No functional impact, purely a refactor for code clarity.
2024-11-15 11:51:56 +08:00
-LAN-
d6c9ab8554 feat(llm_node): allow to use image file directly in the prompt. 2024-11-15 11:51:56 +08:00
-LAN-
bab989e3b3 Remove unnecessary data from log and text properties
Updated the log and text properties in segments to return
empty strings instead of the segment value. This change
prevents potential leakage of sensitive data by ensuring
only non-sensitive information is logged or transformed
into text. Addresses potential security and privacy concerns.
2024-11-15 11:51:56 +08:00
-LAN-
ddc86503dc refactor(model_manager): update parameter type for flexibility
- Changed 'prompt_messages' parameter from list to Sequence for broader input type compatibility.
2024-11-15 11:51:56 +08:00
-LAN-
abad35f700 refactor(memory): use Sequence instead of list for prompt messages
- Improved flexibility by using Sequence instead of list, allowing for broader compatibility with different types of sequences.
- Helps future-proof the method signature by leveraging the more generic Sequence type.
2024-11-15 11:51:56 +08:00
-LAN-
620b0e69f5 fix(dependencies): update Faker version constraint
- Changed the Faker version from caret constraint to tilde constraint for compatibility.
- Updated poetry.lock for changes in pyproject.toml content.
2024-11-15 11:51:56 +08:00
-LAN-
71cf4c7dbf chore(config): remove unnecessary 'frozen' parameter for test
- Simplified app configuration by removing the 'frozen' parameter since it is no longer needed.
- Ensures more flexible handling of config attributes.
2024-11-15 11:51:20 +08:00
-LAN-
47e8a5d4d1 refactor(model_runtime): use Sequence for content in PromptMessage
- Replaced list with Sequence for more flexible content type.
- Improved type consistency by importing from collections.abc.
2024-11-15 11:51:20 +08:00
-LAN-
93bbb194f2 refactor(prompt): enhance type flexibility for prompt messages
- Changed input type from list to Sequence for prompt messages to allow more flexible input types.
- Improved compatibility with functions expecting different iterable types.
2024-11-15 11:51:20 +08:00
-LAN-
2106fc5266 fix(tests): update Azure Rerank Model usage and clean imports 2024-11-15 11:51:20 +08:00
-LAN-
229b146525 feat(errors): add new error classes for unsupported prompt types and memory role prefix requirements 2024-11-15 11:51:20 +08:00
-LAN-
d9fa6f79be refactor: update jinja2_variables and prompt_config to use Sequence and add validators for None handling 2024-11-15 11:51:20 +08:00
-LAN-
4f89214d89 refactor: update stop parameter type to use Sequence instead of list 2024-11-15 11:51:20 +08:00
-LAN-
1fdaea29aa refactor(converter): simplify model credentials validation logic 2024-11-15 11:51:20 +08:00
-LAN-
1397d0000d chore(deps): add faker 2024-11-15 11:51:20 +08:00
非法操作
4b2abf8ac2
fix: create_blob_message of tool will always create image type file (#10701)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
2024-11-15 10:38:12 +08:00
Bowen Liang
365cb4b368
chore(lint): bump ruff from 0.6.9 to 0.7.3 (#10714) 2024-11-15 09:19:41 +08:00
GeorgeCaoJ
c85bff235d
fix(i18n): handle key naming error (#10713) 2024-11-15 09:01:38 +08:00
Kalo Chin
ad16180b1a
feat(tool): fal ai wizper ASR built-in tool (#10716) 2024-11-15 09:01:07 +08:00
36 changed files with 1579 additions and 256 deletions

View File

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file
env_file=".env",
env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes
extra="ignore",
)

View File

@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@ -38,12 +38,8 @@ class ModelConfigConverter:
)
if model_credentials is None:
if not skip_check:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
@ -76,7 +72,7 @@ class ModelConfigConverter:
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema:
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity(

View File

@ -217,6 +217,7 @@ class WorkflowCycleManage:
).total_seconds()
db.session.commit()
db.session.add(workflow_run)
db.session.refresh(workflow_run)
db.session.close()

View File

@ -74,6 +74,8 @@ def to_prompt_message_content(
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _:
raise ValueError("file type f.type is not supported")

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -27,7 +28,7 @@ class TokenBufferMemory:
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]:
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit

View File

@ -100,10 +100,10 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

View File

@ -1,4 +1,5 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum
from typing import Optional
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel):
@ -107,7 +109,7 @@ class PromptMessage(ABC, BaseModel):
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None
def is_empty(self) -> bool:

View File

@ -87,6 +87,9 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum):

View File

@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -212,7 +212,7 @@ if you are not sure about the structure.
)
model_parameters.pop("response_format")
stop = stop or []
stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
@ -408,7 +408,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -479,7 +479,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -601,7 +601,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -647,7 +647,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -694,7 +694,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@ -742,7 +742,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

View File

@ -8,6 +8,7 @@ features:
- agent-thought
- stream-tool-call
- vision
- audio
model_properties:
mode: chat
context_size: 128000

View File

@ -1,5 +1,6 @@
import json
import random
from collections import UserDict
from datetime import datetime
@ -10,9 +11,9 @@ class ChatRole:
FUNCTION = "function"
class _Dict(dict):
__setattr__ = dict.__setitem__
__getattr__ = dict.__getitem__
class _Dict(UserDict):
__setattr__ = UserDict.__setitem__
__getattr__ = UserDict.__getitem__
def __missing__(self, key):
return None

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import cast
from core.model_runtime.entities import (
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil:
@staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
"""
Prompt messages to prompt for saving.
:param model_mode: model mode

View File

@ -0,0 +1,52 @@
import io
import os
from typing import Any
import fal_client
from core.file.enums import FileAttribute, FileType
from core.file.file_manager import download, get_attr
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class WizperTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
audio_file = tool_parameters.get("audio_file")
task = tool_parameters.get("task", "transcribe")
language = tool_parameters.get("language", "en")
chunk_level = tool_parameters.get("chunk_level", "segment")
version = tool_parameters.get("version", "3")
if audio_file.type != FileType.AUDIO:
return [self.create_text_message("Not a valid audio file.")]
api_key = self.runtime.credentials["fal_api_key"]
os.environ["FAL_KEY"] = api_key
audio_binary = io.BytesIO(download(audio_file))
mime_type = get_attr(file=audio_file, attr=FileAttribute.MIME_TYPE)
file_data = audio_binary.getvalue()
try:
audio_url = fal_client.upload(file_data, mime_type)
except Exception as e:
return [self.create_text_message(f"Error uploading audio file: {str(e)}")]
arguments = {
"audio_url": audio_url,
"task": task,
"language": language,
"chunk_level": chunk_level,
"version": version,
}
result = fal_client.subscribe(
"fal-ai/wizper",
arguments=arguments,
with_logs=False,
)
return self.create_json_message(result)

View File

@ -0,0 +1,489 @@
identity:
name: wizper
author: Kalo Chin
label:
en_US: Wizper
zh_Hans: Wizper
description:
human:
en_US: Transcribe an audio file using the Whisper model.
zh_Hans: 使用 Whisper 模型转录音频文件。
llm: Transcribe an audio file using the Whisper model.
parameters:
- name: audio_file
type: file
required: true
label:
en_US: Audio File
zh_Hans: 音频文件
human_description:
en_US: "Upload an audio file to transcribe. Supports mp3, mp4, mpeg, mpga, m4a, wav, or webm formats."
zh_Hans: "上传要转录的音频文件。支持 mp3、mp4、mpeg、mpga、m4a、wav 或 webm 格式。"
llm_description: "Audio file to transcribe. Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm."
form: llm
- name: task
type: select
required: true
label:
en_US: Task
zh_Hans: 任务
human_description:
en_US: "Choose whether to transcribe the audio in its original language or translate it to English"
zh_Hans: "选择是以原始语言转录音频还是将其翻译成英语"
llm_description: "Task to perform on the audio file. Either transcribe or translate. Default value: 'transcribe'. If 'translate' is selected as the task, the audio will be translated to English, regardless of the language selected."
form: form
default: transcribe
options:
- value: transcribe
label:
en_US: Transcribe
zh_Hans: 转录
- value: translate
label:
en_US: Translate
zh_Hans: 翻译
- name: language
type: select
required: true
label:
en_US: Language
zh_Hans: 语言
human_description:
en_US: "Select the primary language spoken in the audio file"
zh_Hans: "选择音频文件中使用的主要语言"
llm_description: "Language of the audio file."
form: form
default: en
options:
- value: af
label:
en_US: Afrikaans
zh_Hans: 南非语
- value: am
label:
en_US: Amharic
zh_Hans: 阿姆哈拉语
- value: ar
label:
en_US: Arabic
zh_Hans: 阿拉伯语
- value: as
label:
en_US: Assamese
zh_Hans: 阿萨姆语
- value: az
label:
en_US: Azerbaijani
zh_Hans: 阿塞拜疆语
- value: ba
label:
en_US: Bashkir
zh_Hans: 巴什基尔语
- value: be
label:
en_US: Belarusian
zh_Hans: 白俄罗斯语
- value: bg
label:
en_US: Bulgarian
zh_Hans: 保加利亚语
- value: bn
label:
en_US: Bengali
zh_Hans: 孟加拉语
- value: bo
label:
en_US: Tibetan
zh_Hans: 藏语
- value: br
label:
en_US: Breton
zh_Hans: 布列塔尼语
- value: bs
label:
en_US: Bosnian
zh_Hans: 波斯尼亚语
- value: ca
label:
en_US: Catalan
zh_Hans: 加泰罗尼亚语
- value: cs
label:
en_US: Czech
zh_Hans: 捷克语
- value: cy
label:
en_US: Welsh
zh_Hans: 威尔士语
- value: da
label:
en_US: Danish
zh_Hans: 丹麦语
- value: de
label:
en_US: German
zh_Hans: 德语
- value: el
label:
en_US: Greek
zh_Hans: 希腊语
- value: en
label:
en_US: English
zh_Hans: 英语
- value: es
label:
en_US: Spanish
zh_Hans: 西班牙语
- value: et
label:
en_US: Estonian
zh_Hans: 爱沙尼亚语
- value: eu
label:
en_US: Basque
zh_Hans: 巴斯克语
- value: fa
label:
en_US: Persian
zh_Hans: 波斯语
- value: fi
label:
en_US: Finnish
zh_Hans: 芬兰语
- value: fo
label:
en_US: Faroese
zh_Hans: 法罗语
- value: fr
label:
en_US: French
zh_Hans: 法语
- value: gl
label:
en_US: Galician
zh_Hans: 加利西亚语
- value: gu
label:
en_US: Gujarati
zh_Hans: 古吉拉特语
- value: ha
label:
en_US: Hausa
zh_Hans: 毫萨语
- value: haw
label:
en_US: Hawaiian
zh_Hans: 夏威夷语
- value: he
label:
en_US: Hebrew
zh_Hans: 希伯来语
- value: hi
label:
en_US: Hindi
zh_Hans: 印地语
- value: hr
label:
en_US: Croatian
zh_Hans: 克罗地亚语
- value: ht
label:
en_US: Haitian Creole
zh_Hans: 海地克里奥尔语
- value: hu
label:
en_US: Hungarian
zh_Hans: 匈牙利语
- value: hy
label:
en_US: Armenian
zh_Hans: 亚美尼亚语
- value: id
label:
en_US: Indonesian
zh_Hans: 印度尼西亚语
- value: is
label:
en_US: Icelandic
zh_Hans: 冰岛语
- value: it
label:
en_US: Italian
zh_Hans: 意大利语
- value: ja
label:
en_US: Japanese
zh_Hans: 日语
- value: jw
label:
en_US: Javanese
zh_Hans: 爪哇语
- value: ka
label:
en_US: Georgian
zh_Hans: 格鲁吉亚语
- value: kk
label:
en_US: Kazakh
zh_Hans: 哈萨克语
- value: km
label:
en_US: Khmer
zh_Hans: 高棉语
- value: kn
label:
en_US: Kannada
zh_Hans: 卡纳达语
- value: ko
label:
en_US: Korean
zh_Hans: 韩语
- value: la
label:
en_US: Latin
zh_Hans: 拉丁语
- value: lb
label:
en_US: Luxembourgish
zh_Hans: 卢森堡语
- value: ln
label:
en_US: Lingala
zh_Hans: 林加拉语
- value: lo
label:
en_US: Lao
zh_Hans: 老挝语
- value: lt
label:
en_US: Lithuanian
zh_Hans: 立陶宛语
- value: lv
label:
en_US: Latvian
zh_Hans: 拉脱维亚语
- value: mg
label:
en_US: Malagasy
zh_Hans: 马尔加什语
- value: mi
label:
en_US: Maori
zh_Hans: 毛利语
- value: mk
label:
en_US: Macedonian
zh_Hans: 马其顿语
- value: ml
label:
en_US: Malayalam
zh_Hans: 马拉雅拉姆语
- value: mn
label:
en_US: Mongolian
zh_Hans: 蒙古语
- value: mr
label:
en_US: Marathi
zh_Hans: 马拉地语
- value: ms
label:
en_US: Malay
zh_Hans: 马来语
- value: mt
label:
en_US: Maltese
zh_Hans: 马耳他语
- value: my
label:
en_US: Burmese
zh_Hans: 缅甸语
- value: ne
label:
en_US: Nepali
zh_Hans: 尼泊尔语
- value: nl
label:
en_US: Dutch
zh_Hans: 荷兰语
- value: nn
label:
en_US: Norwegian Nynorsk
zh_Hans: 新挪威语
- value: no
label:
en_US: Norwegian
zh_Hans: 挪威语
- value: oc
label:
en_US: Occitan
zh_Hans: 奥克语
- value: pa
label:
en_US: Punjabi
zh_Hans: 旁遮普语
- value: pl
label:
en_US: Polish
zh_Hans: 波兰语
- value: ps
label:
en_US: Pashto
zh_Hans: 普什图语
- value: pt
label:
en_US: Portuguese
zh_Hans: 葡萄牙语
- value: ro
label:
en_US: Romanian
zh_Hans: 罗马尼亚语
- value: ru
label:
en_US: Russian
zh_Hans: 俄语
- value: sa
label:
en_US: Sanskrit
zh_Hans: 梵语
- value: sd
label:
en_US: Sindhi
zh_Hans: 信德语
- value: si
label:
en_US: Sinhala
zh_Hans: 僧伽罗语
- value: sk
label:
en_US: Slovak
zh_Hans: 斯洛伐克语
- value: sl
label:
en_US: Slovenian
zh_Hans: 斯洛文尼亚语
- value: sn
label:
en_US: Shona
zh_Hans: 修纳语
- value: so
label:
en_US: Somali
zh_Hans: 索马里语
- value: sq
label:
en_US: Albanian
zh_Hans: 阿尔巴尼亚语
- value: sr
label:
en_US: Serbian
zh_Hans: 塞尔维亚语
- value: su
label:
en_US: Sundanese
zh_Hans: 巽他语
- value: sv
label:
en_US: Swedish
zh_Hans: 瑞典语
- value: sw
label:
en_US: Swahili
zh_Hans: 斯瓦希里语
- value: ta
label:
en_US: Tamil
zh_Hans: 泰米尔语
- value: te
label:
en_US: Telugu
zh_Hans: 泰卢固语
- value: tg
label:
en_US: Tajik
zh_Hans: 塔吉克语
- value: th
label:
en_US: Thai
zh_Hans: 泰语
- value: tk
label:
en_US: Turkmen
zh_Hans: 土库曼语
- value: tl
label:
en_US: Tagalog
zh_Hans: 他加禄语
- value: tr
label:
en_US: Turkish
zh_Hans: 土耳其语
- value: tt
label:
en_US: Tatar
zh_Hans: 鞑靼语
- value: uk
label:
en_US: Ukrainian
zh_Hans: 乌克兰语
- value: ur
label:
en_US: Urdu
zh_Hans: 乌尔都语
- value: uz
label:
en_US: Uzbek
zh_Hans: 乌兹别克语
- value: vi
label:
en_US: Vietnamese
zh_Hans: 越南语
- value: yi
label:
en_US: Yiddish
zh_Hans: 意第绪语
- value: yo
label:
en_US: Yoruba
zh_Hans: 约鲁巴语
- value: yue
label:
en_US: Cantonese
zh_Hans: 粤语
- value: zh
label:
en_US: Chinese
zh_Hans: 中文
- name: chunk_level
type: select
label:
en_US: Chunk Level
zh_Hans: 分块级别
human_description:
en_US: "Choose how the transcription should be divided into chunks"
zh_Hans: "选择如何将转录内容分成块"
llm_description: "Level of the chunks to return."
form: form
default: segment
options:
- value: segment
label:
en_US: Segment
zh_Hans:
- name: version
type: select
label:
en_US: Version
zh_Hans: 版本
human_description:
en_US: "Select which version of the Whisper large model to use"
zh_Hans: "选择要使用的 Whisper large 模型版本"
llm_description: "Version of the model to use. All of the models are the Whisper large variant."
form: form
default: "3"
options:
- value: "3"
label:
en_US: Version 3
zh_Hans: 版本 3

View File

@ -118,11 +118,11 @@ class FileSegment(Segment):
@property
def log(self) -> str:
return str(self.value)
return ""
@property
def text(self) -> str:
return str(self.value)
return ""
class ArrayAnySegment(ArraySegment):
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
for item in self.value:
items.append(item.markdown)
return "\n".join(items)
@property
def log(self) -> str:
return ""
@property
def text(self) -> str:
return ""

View File

@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
class PromptConfig(BaseModel):
jinja2_variables: Optional[list[VariableSelector]] = None
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
if v is None:
return []
return v
class LLMNodeChatModelMessage(ChatModelMessage):
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
if v is None:
return PromptConfig()
return v

View File

@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration."""
class NotSupportedPromptTypeError(LLMNodeError):
"""Raised when the prompt type is not supported."""
class MemoryRolePrefixRequiredError(LLMNodeError):
"""Raised when memory role prefix is required for completion model."""

View File

@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
@ -32,8 +38,9 @@ from core.variables import (
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@ -62,14 +69,18 @@ from .exc import (
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
NoPromptFoundError,
NotSupportedPromptTypeError,
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.file.models import File
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch prompt messages
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise VariableNotFoundError("Query not found")
query = query.text
query = self.node_data.memory.query_prompt_template
else:
query = None
prompt_messages, stop = self._fetch_prompt_messages(
system_query=query,
inputs=inputs,
files=files,
user_query=query,
user_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)
process_data = {
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run: {e}")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
user_query: str | None = None,
user_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = inputs or {}
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=system_query or "",
files=files,
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
messages=prompt_template,
context=context,
memory_config=memory_config,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
stop = model_config.stop
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
errmsg = f"Prompt type {type(prompt_template)} is not supported"
logger.warning(errmsg)
raise NotSupportedPromptTypeError(errmsg)
if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content or []:
# Skip image if vision is disabled
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
continue
if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
):
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
prompt_message.content = prompt_message_content[0].data
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
stop = model_config.stop
return filtered_prompt_messages, stop
@classmethod
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
}
},
}
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
system_query=query,
user_query=query,
memory=memory,
model_config=model_config,
files=files,
user_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
# handle invoke result

View File

@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Any
from sqlalchemy import select
@ -180,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
@ -202,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
)
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
@ -211,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
raise ValueError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
file = file_factory.build_from_mapping(
@ -228,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
extension = ".bin"
mapping = {
"tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": transfer_method,
"url": url,
}

View File

@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
return mime_type, filename, file_size
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type
def _build_from_tool_file(
*,
mapping: Mapping[str, Any],
@ -199,12 +213,13 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
return File(
id=mapping.get("id"),
tenant_id=tenant_id,
filename=tool_file.name,
type=FileType.value_of(mapping.get("type")),
type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
related_id=tool_file.id,

86
api/poetry.lock generated
View File

@ -2411,6 +2411,41 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "faker"
version = "32.1.0"
description = "Faker is a Python package that generates fake data for you."
optional = false
python-versions = ">=3.8"
files = [
{file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
{file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
]
[package.dependencies]
python-dateutil = ">=2.4"
typing-extensions = "*"
[[package]]
name = "fal-client"
version = "0.5.6"
description = "Python client for fal.ai"
optional = false
python-versions = ">=3.8"
files = [
{file = "fal_client-0.5.6-py3-none-any.whl", hash = "sha256:631fd857a3c44753ee46a2eea1e7276471453aca58faac9c3702f744c7c84050"},
{file = "fal_client-0.5.6.tar.gz", hash = "sha256:d3afc4b6250023d0ee8437ec504558231d3b106d7aabc12cda8c39883faddecb"},
]
[package.dependencies]
httpx = ">=0.21.0,<1"
httpx-sse = ">=0.4.0,<0.5"
[package.extras]
dev = ["fal-client[docs,test]"]
docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme"]
test = ["pillow", "pytest", "pytest-asyncio"]
[[package]]
name = "fastapi"
version = "0.115.4"
@ -4049,6 +4084,17 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.0"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
]
[[package]]
name = "huggingface-hub"
version = "0.16.4"
@ -8466,29 +8512,29 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.6.9"
version = "0.7.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"},
{file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"},
{file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"},
{file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"},
{file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"},
{file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"},
{file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"},
{file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"},
{file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"},
{file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"},
{file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"},
{file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"},
{file = "ruff-0.7.3-py3-none-linux_armv6l.whl", hash = "sha256:34f2339dc22687ec7e7002792d1f50712bf84a13d5152e75712ac08be565d344"},
{file = "ruff-0.7.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:fb397332a1879b9764a3455a0bb1087bda876c2db8aca3a3cbb67b3dbce8cda0"},
{file = "ruff-0.7.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:37d0b619546103274e7f62643d14e1adcbccb242efda4e4bdb9544d7764782e9"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d59f0c3ee4d1a6787614e7135b72e21024875266101142a09a61439cb6e38a5"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:44eb93c2499a169d49fafd07bc62ac89b1bc800b197e50ff4633aed212569299"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d0242ce53f3a576c35ee32d907475a8d569944c0407f91d207c8af5be5dae4e"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6b6224af8b5e09772c2ecb8dc9f3f344c1aa48201c7f07e7315367f6dd90ac29"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c50f95a82b94421c964fae4c27c0242890a20fe67d203d127e84fbb8013855f5"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f3eff9961b5d2644bcf1616c606e93baa2d6b349e8aa8b035f654df252c8c67"},
{file = "ruff-0.7.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8963cab06d130c4df2fd52c84e9f10d297826d2e8169ae0c798b6221be1d1d2"},
{file = "ruff-0.7.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:61b46049d6edc0e4317fb14b33bd693245281a3007288b68a3f5b74a22a0746d"},
{file = "ruff-0.7.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:10ebce7696afe4644e8c1a23b3cf8c0f2193a310c18387c06e583ae9ef284de2"},
{file = "ruff-0.7.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3f36d56326b3aef8eeee150b700e519880d1aab92f471eefdef656fd57492aa2"},
{file = "ruff-0.7.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5d024301109a0007b78d57ab0ba190087b43dce852e552734ebf0b0b85e4fb16"},
{file = "ruff-0.7.3-py3-none-win32.whl", hash = "sha256:4ba81a5f0c5478aa61674c5a2194de8b02652f17addf8dfc40c8937e6e7d79fc"},
{file = "ruff-0.7.3-py3-none-win_amd64.whl", hash = "sha256:588a9ff2fecf01025ed065fe28809cd5a53b43505f48b69a1ac7707b1b7e4088"},
{file = "ruff-0.7.3-py3-none-win_arm64.whl", hash = "sha256:1713e2c5545863cdbfe2cbce21f69ffaf37b813bfd1fb3b90dc9a6f1963f5a8c"},
{file = "ruff-0.7.3.tar.gz", hash = "sha256:e1d1ba2e40b6e71a61b063354d04be669ab0d39c352461f3d789cac68b54a313"},
]
[[package]]
@ -11005,4 +11051,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "f20bd678044926913dbbc24bd0cf22503a75817aa55f59457ff7822032139b77"
content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"

View File

@ -122,6 +122,7 @@ celery = "~5.4.0"
chardet = "~5.1.0"
cohere = "~5.2.4"
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
fal-client = "0.5.6"
flask = "~3.0.1"
flask-compress = "~1.14"
flask-cors = "~4.0.0"
@ -265,6 +266,7 @@ weaviate-client = "~3.21.0"
optional = true
[tool.poetry.group.dev.dependencies]
coverage = "~7.2.4"
faker = "~32.1.0"
pytest = "~8.3.2"
pytest-benchmark = "~4.0.0"
pytest-env = "~1.1.3"
@ -278,4 +280,4 @@ pytest-mock = "~3.14.0"
optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
ruff = "~0.6.9"
ruff = "~0.7.3"

View File

@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

View File

@ -4,29 +4,21 @@ import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
def test_validate_credentials():
model = AzureAIStudioRerankModel()
model = AzureRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="azure-ai-studio-rerank-v1",
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
def test_invoke_model():
model = AzureAIStudioRerankModel()
model = AzureRerankModel()
result = model.invoke(
model="azure-ai-studio-rerank-v1",

View File

@ -1,4 +1,5 @@
import os
from collections import UserDict
from unittest.mock import MagicMock
import pytest
@ -11,7 +12,7 @@ from pymochow.model.table import Table
from requests.adapters import HTTPAdapter
class AttrDict(dict):
class AttrDict(UserDict):
def __getattr__(self, item):
return self.get(item)

View File

@ -1,4 +1,5 @@
import os
from collections import UserDict
from typing import Optional
import pytest
@ -50,7 +51,7 @@ class MockIndex:
return AttrDict({"dimension": 1024})
class AttrDict(dict):
class AttrDict(UserDict):
def __getattr__(self, item):
return self.get(item)

View File

@ -1,22 +1,61 @@
from collections.abc import Sequence
from typing import Optional
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
class MockTokenBufferMemory:
def __init__(self, history_messages=None):
self.history_messages = history_messages or []
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]
return self.history_messages
class TestLLMNode:
@pytest.fixture
def llm_node(self):
def llm_node():
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
@ -70,7 +109,48 @@ class TestLLMNode:
)
return node
def test_fetch_files_with_file_segment(self, llm_node):
@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory()
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance.get_provider_schema(),
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
provider_instance=provider_instance,
model_type_instance=model_type_instance,
)
# Create and return a ModelConfigWithCredentialsEntity
return ModelConfigWithCredentialsEntity(
provider="openai",
model="gpt-3.5-turbo",
model_schema=AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(en_US="GPT-3.5 Turbo"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={},
),
mode="chat",
credentials={},
parameters={},
provider_model_bundle=provider_model_bundle,
)
def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
@ -84,7 +164,8 @@ class TestLLMNode:
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(self, llm_node):
def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1",
@ -108,18 +189,296 @@ class TestLLMNode:
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(self, llm_node):
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(self, llm_node):
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(self, llm_node):
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
prompt_template = []
llm_node.node_data.prompt_template = prompt_template
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
]
fake_query = faker.sentence()
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
context=None,
memory=None,
model_config=model_config,
prompt_template=prompt_template,
memory_config=None,
vision_enabled=False,
vision_detail=fake_vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
assert prompt_messages == [UserPromptMessage(content=fake_query)]
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
fake_query = faker.sentence()
fake_context = faker.sentence()
fake_window_size = faker.random_int(min=1, max=3)
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
# Setup mock memory with history messages
mock_history = [
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
]
# Setup memory configuration
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
query_prompt_template=None,
)
memory = MockTokenBufferMemory(history_messages=mock_history)
# Test scenarios covering different file input combinations
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(content=fake_query),
],
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
]
+ mock_history[fake_window_size * -2 :]
+ [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File without vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File with video file and vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.VIDEO,
filename="test1.mp4",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension="mp4",
)
},
),
]
for scenario in test_scenarios:
model_config.model_schema.features = scenario.features
for k, v in scenario.file_variables.items():
selector = k.split(".")
llm_node.graph_runtime_state.variable_pool.add(selector, v)
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
context=fake_context,
memory=memory,
model_config=model_config,
prompt_template=scenario.prompt_template,
memory_config=memory_config,
vision_enabled=scenario.vision_enabled,
vision_detail=scenario.vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
# Verify the result
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"

View File

@ -0,0 +1,25 @@
from collections.abc import Mapping, Sequence
from pydantic import BaseModel, Field
from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
file_variables: Mapping[str, File | Sequence[File]] = Field(
default_factory=dict, description="List of file variables"
)
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")

View File

@ -1,4 +1,5 @@
import os
from collections import UserDict
from unittest.mock import MagicMock
import pytest
@ -14,7 +15,7 @@ from tests.unit_tests.oss.__mock.base import (
)
class AttrDict(dict):
class AttrDict(UserDict):
def __getattr__(self, item):
return self.get(item)

View File

@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
}
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
if (fileMimetype)
return mime.getExtension(fileMimetype) || ''
if (isRemote)
return ''
if (fileName) {
const fileNamePair = fileName.split('.')
const fileNamePairLength = fileNamePair.length
@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
return fileNamePair[fileNamePairLength - 1]
}
if (fileMimetype)
return mime.getExtension(fileMimetype) || ''
if (isRemote)
return ''
return ''
}

View File

@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
onEditionTypeChange={onEditionTypeChange}
varList={varList}
handleAddVariable={handleAddVariable}
isSupportFileVar
/>
)
}

View File

@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
handleStop,
varInputs,
runResult,
filterJinjia2InputVar,
} = useConfig(id, data)
const model = inputs.model
@ -194,7 +195,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
list={inputs.prompt_config?.jinja2_variables || []}
onChange={handleVarListChange}
onVarNameChange={handleVarNameChange}
filterVar={filterVar}
filterVar={filterJinjia2InputVar}
/>
</Field>
)}
@ -233,6 +234,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
hasSetBlockStatus={hasSetBlockStatus}
nodesOutputVars={availableVars}
availableNodes={availableNodesWithParent}
isSupportFileVar
/>
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (

View File

@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
}, [inputs, setInputs])
const filterInputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
}, [])
const filterJinjia2InputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
}, [])
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
}, [])
const {
@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
handleRun,
handleStop,
runResult,
filterJinjia2InputVar,
}
}

View File

@ -86,8 +86,8 @@ const translation = {
agenteLogDetail: {
agentMode: 'Modo Agente',
toolUsed: 'Ferramenta usada',
iterações: 'Iterações',
iteração: 'Iteração',
iterations: 'Iterações',
iteration: 'Iteração',
finalProcessing: 'Processamento Final',
},
agentLogDetail: {