mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Merge abacc3768f
into d05fee1182
This commit is contained in:
commit
85deb8b31c
|
@ -27,7 +27,6 @@ class DifyConfig(
|
||||||
# read from dotenv format config file
|
# read from dotenv format config file
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
frozen=True,
|
|
||||||
# ignore extra attributes
|
# ignore extra attributes
|
||||||
extra="ignore",
|
extra="ignore",
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
|
||||||
|
|
||||||
class ModelConfigConverter:
|
class ModelConfigConverter:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
|
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
|
||||||
"""
|
"""
|
||||||
Convert app model config dict to entity.
|
Convert app model config dict to entity.
|
||||||
:param app_config: app config
|
:param app_config: app config
|
||||||
|
@ -38,27 +38,23 @@ class ModelConfigConverter:
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_credentials is None:
|
if model_credentials is None:
|
||||||
if not skip_check:
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
|
||||||
else:
|
|
||||||
model_credentials = {}
|
|
||||||
|
|
||||||
if not skip_check:
|
# check model
|
||||||
# check model
|
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
model=model_config.model, model_type=ModelType.LLM
|
||||||
model=model_config.model, model_type=ModelType.LLM
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if provider_model is None:
|
if provider_model is None:
|
||||||
model_name = model_config.model
|
model_name = model_config.model
|
||||||
raise ValueError(f"Model {model_name} not exist.")
|
raise ValueError(f"Model {model_name} not exist.")
|
||||||
|
|
||||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||||
|
|
||||||
# model config
|
# model config
|
||||||
completion_params = model_config.parameters
|
completion_params = model_config.parameters
|
||||||
|
@ -76,7 +72,7 @@ class ModelConfigConverter:
|
||||||
|
|
||||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||||
|
|
||||||
if not skip_check and not model_schema:
|
if not model_schema:
|
||||||
raise ValueError(f"Model {model_name} not exist.")
|
raise ValueError(f"Model {model_name} not exist.")
|
||||||
|
|
||||||
return ModelConfigWithCredentialsEntity(
|
return ModelConfigWithCredentialsEntity(
|
||||||
|
|
|
@ -217,6 +217,7 @@ class WorkflowCycleManage:
|
||||||
).total_seconds()
|
).total_seconds()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
db.session.add(workflow_run)
|
||||||
db.session.refresh(workflow_run)
|
db.session.refresh(workflow_run)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,8 @@ def to_prompt_message_content(
|
||||||
data = _to_url(f)
|
data = _to_url(f)
|
||||||
else:
|
else:
|
||||||
data = _to_base64_data_string(f)
|
data = _to_base64_data_string(f)
|
||||||
|
if f.extension is None:
|
||||||
|
raise ValueError("Missing file extension")
|
||||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||||
case _:
|
case _:
|
||||||
raise ValueError("file type f.type is not supported")
|
raise ValueError("file type f.type is not supported")
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
@ -27,7 +28,7 @@ class TokenBufferMemory:
|
||||||
|
|
||||||
def get_history_prompt_messages(
|
def get_history_prompt_messages(
|
||||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
) -> list[PromptMessage]:
|
) -> Sequence[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Get history prompt messages.
|
Get history prompt messages.
|
||||||
:param max_token_limit: max token limit
|
:param max_token_limit: max token limit
|
||||||
|
|
|
@ -100,10 +100,10 @@ class ModelInstance:
|
||||||
|
|
||||||
def invoke_llm(
|
def invoke_llm(
|
||||||
self,
|
self,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: Optional[dict] = None,
|
model_parameters: Optional[dict] = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
|
@ -31,7 +32,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -60,7 +61,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
@ -90,7 +91,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -120,7 +121,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
AUDIO = "audio"
|
AUDIO = "audio"
|
||||||
VIDEO = "video"
|
VIDEO = "video"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageContent(BaseModel):
|
class PromptMessageContent(BaseModel):
|
||||||
|
@ -107,7 +109,7 @@ class PromptMessage(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: PromptMessageRole
|
role: PromptMessageRole
|
||||||
content: Optional[str | list[PromptMessageContent]] = None
|
content: Optional[str | Sequence[PromptMessageContent]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
|
|
|
@ -87,6 +87,9 @@ class ModelFeature(Enum):
|
||||||
AGENT_THOUGHT = "agent-thought"
|
AGENT_THOUGHT = "agent-thought"
|
||||||
VISION = "vision"
|
VISION = "vision"
|
||||||
STREAM_TOOL_CALL = "stream-tool-call"
|
STREAM_TOOL_CALL = "stream-tool-call"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
VIDEO = "video"
|
||||||
|
AUDIO = "audio"
|
||||||
|
|
||||||
|
|
||||||
class DefaultParameterName(str, Enum):
|
class DefaultParameterName(str, Enum):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: Optional[dict] = None,
|
model_parameters: Optional[dict] = None,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -212,7 +212,7 @@ if you are not sure about the structure.
|
||||||
)
|
)
|
||||||
|
|
||||||
model_parameters.pop("response_format")
|
model_parameters.pop("response_format")
|
||||||
stop = stop or []
|
stop = list(stop) if stop is not None else []
|
||||||
stop.extend(["\n```", "```\n"])
|
stop.extend(["\n```", "```\n"])
|
||||||
block_prompts = block_prompts.replace("{{block}}", code_block)
|
block_prompts = block_prompts.replace("{{block}}", code_block)
|
||||||
|
|
||||||
|
@ -408,7 +408,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -479,7 +479,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
|
@ -601,7 +601,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -647,7 +647,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -694,7 +694,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -742,7 +742,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
|
|
@ -8,6 +8,7 @@ features:
|
||||||
- agent-thought
|
- agent-thought
|
||||||
- stream-tool-call
|
- stream-tool-call
|
||||||
- vision
|
- vision
|
||||||
|
- audio
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: chat
|
mode: chat
|
||||||
context_size: 128000
|
context_size: 128000
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
|
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||||
|
|
||||||
class PromptMessageUtil:
|
class PromptMessageUtil:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
|
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Prompt messages to prompt for saving.
|
Prompt messages to prompt for saving.
|
||||||
:param model_mode: model mode
|
:param model_mode: model mode
|
||||||
|
|
|
@ -118,11 +118,11 @@ class FileSegment(Segment):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def log(self) -> str:
|
def log(self) -> str:
|
||||||
return str(self.value)
|
return ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
return str(self.value)
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
|
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
|
||||||
for item in self.value:
|
for item in self.value:
|
||||||
items.append(item.markdown)
|
items.append(item.markdown)
|
||||||
return "\n".join(items)
|
return "\n".join(items)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
|
@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PromptConfig(BaseModel):
|
class PromptConfig(BaseModel):
|
||||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("jinja2_variables", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_none_jinja2_variables(cls, v: Any):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||||
|
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||||
class LLMNodeData(BaseNodeData):
|
class LLMNodeData(BaseNodeData):
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||||
prompt_config: Optional[PromptConfig] = None
|
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||||
memory: Optional[MemoryConfig] = None
|
memory: Optional[MemoryConfig] = None
|
||||||
context: ContextConfig
|
context: ContextConfig
|
||||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||||
|
|
||||||
|
@field_validator("prompt_config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_none_prompt_config(cls, v: Any):
|
||||||
|
if v is None:
|
||||||
|
return PromptConfig()
|
||||||
|
return v
|
||||||
|
|
|
@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
|
||||||
|
|
||||||
class NoPromptFoundError(LLMNodeError):
|
class NoPromptFoundError(LLMNodeError):
|
||||||
"""Raised when no prompt is found in the LLM configuration."""
|
"""Raised when no prompt is found in the LLM configuration."""
|
||||||
|
|
||||||
|
|
||||||
|
class NotSupportedPromptTypeError(LLMNodeError):
|
||||||
|
"""Raised when the prompt type is not supported."""
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||||
|
"""Raised when memory role prefix is required for completion model."""
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
|
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.entities.provider_entities import QuotaUnit
|
from core.entities.provider_entities import QuotaUnit
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
|
from core.file import FileType, file_manager
|
||||||
|
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AudioPromptMessageContent,
|
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
VideoPromptMessageContent,
|
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
|
@ -32,8 +38,9 @@ from core.variables import (
|
||||||
ObjectSegment,
|
ObjectSegment,
|
||||||
StringSegment,
|
StringSegment,
|
||||||
)
|
)
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
@ -62,14 +69,18 @@ from .exc import (
|
||||||
InvalidVariableTypeError,
|
InvalidVariableTypeError,
|
||||||
LLMModeRequiredError,
|
LLMModeRequiredError,
|
||||||
LLMNodeError,
|
LLMNodeError,
|
||||||
|
MemoryRolePrefixRequiredError,
|
||||||
ModelNotExistError,
|
ModelNotExistError,
|
||||||
NoPromptFoundError,
|
NoPromptFoundError,
|
||||||
|
NotSupportedPromptTypeError,
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(BaseNode[LLMNodeData]):
|
class LLMNode(BaseNode[LLMNodeData]):
|
||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
|
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
|
|
||||||
# fetch prompt messages
|
# fetch prompt messages
|
||||||
if self.node_data.memory:
|
if self.node_data.memory:
|
||||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
query = self.node_data.memory.query_prompt_template
|
||||||
if not query:
|
|
||||||
raise VariableNotFoundError("Query not found")
|
|
||||||
query = query.text
|
|
||||||
else:
|
else:
|
||||||
query = None
|
query = None
|
||||||
|
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
system_query=query,
|
user_query=query,
|
||||||
inputs=inputs,
|
user_files=files,
|
||||||
files=files,
|
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
memory_config=self.node_data.memory,
|
memory_config=self.node_data.memory,
|
||||||
vision_enabled=self.node_data.vision.enabled,
|
vision_enabled=self.node_data.vision.enabled,
|
||||||
vision_detail=self.node_data.vision.configs.detail,
|
vision_detail=self.node_data.vision.configs.detail,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = {
|
process_data = {
|
||||||
|
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Node {self.node_id} failed to run: {e}")
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=str(e),
|
||||||
|
inputs=node_inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||||
|
|
||||||
|
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
self,
|
self,
|
||||||
node_data_model: ModelConfig,
|
node_data_model: ModelConfig,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
) -> Generator[NodeEvent, None, None]:
|
) -> Generator[NodeEvent, None, None]:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
def _fetch_prompt_messages(
|
def _fetch_prompt_messages(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
system_query: str | None = None,
|
user_query: str | None = None,
|
||||||
inputs: dict[str, str] | None = None,
|
user_files: Sequence["File"],
|
||||||
files: Sequence["File"],
|
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
memory_config: MemoryConfig | None = None,
|
memory_config: MemoryConfig | None = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
variable_pool: VariablePool,
|
||||||
inputs = inputs or {}
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||||
|
prompt_messages = []
|
||||||
|
|
||||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
if isinstance(prompt_template, list):
|
||||||
prompt_messages = prompt_transform.get_prompt(
|
# For chat model
|
||||||
prompt_template=prompt_template,
|
prompt_messages.extend(
|
||||||
inputs=inputs,
|
_handle_list_messages(
|
||||||
query=system_query or "",
|
messages=prompt_template,
|
||||||
files=files,
|
context=context,
|
||||||
context=context,
|
jinja2_variables=jinja2_variables,
|
||||||
memory_config=memory_config,
|
variable_pool=variable_pool,
|
||||||
memory=memory,
|
vision_detail_config=vision_detail,
|
||||||
model_config=model_config,
|
)
|
||||||
)
|
)
|
||||||
stop = model_config.stop
|
|
||||||
|
# Get memory messages for chat mode
|
||||||
|
memory_messages = _handle_memory_chat_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Extend prompt_messages with memory messages
|
||||||
|
prompt_messages.extend(memory_messages)
|
||||||
|
|
||||||
|
# Add current query to the prompt messages
|
||||||
|
if user_query:
|
||||||
|
message = LLMNodeChatModelMessage(
|
||||||
|
text=user_query,
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
)
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_list_messages(
|
||||||
|
messages=[message],
|
||||||
|
context="",
|
||||||
|
jinja2_variables=[],
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
|
# For completion model
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_completion_template(
|
||||||
|
template=prompt_template,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get memory text for completion model
|
||||||
|
memory_text = _handle_memory_completion_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Insert histories into the prompt
|
||||||
|
prompt_content = prompt_messages[0].content
|
||||||
|
if "#histories#" in prompt_content:
|
||||||
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
|
else:
|
||||||
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
|
||||||
|
# Add current query to the prompt message
|
||||||
|
if user_query:
|
||||||
|
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
else:
|
||||||
|
errmsg = f"Prompt type {type(prompt_template)} is not supported"
|
||||||
|
logger.warning(errmsg)
|
||||||
|
raise NotSupportedPromptTypeError(errmsg)
|
||||||
|
|
||||||
|
if vision_enabled and user_files:
|
||||||
|
file_prompts = []
|
||||||
|
for file in user_files:
|
||||||
|
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||||
|
file_prompts.append(file_prompt)
|
||||||
|
if (
|
||||||
|
len(prompt_messages) > 0
|
||||||
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
|
and isinstance(prompt_messages[-1].content, list)
|
||||||
|
):
|
||||||
|
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
|
# Filter prompt messages
|
||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if prompt_message.is_empty():
|
if isinstance(prompt_message.content, list):
|
||||||
continue
|
|
||||||
|
|
||||||
if not isinstance(prompt_message.content, str):
|
|
||||||
prompt_message_content = []
|
prompt_message_content = []
|
||||||
for content_item in prompt_message.content or []:
|
for content_item in prompt_message.content:
|
||||||
# Skip image if vision is disabled
|
# Skip content if features are not defined
|
||||||
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
|
if not model_config.model_schema.features:
|
||||||
|
if content_item.type != PromptMessageContentType.TEXT:
|
||||||
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(content_item, ImagePromptMessageContent):
|
# Skip content if corresponding feature is not supported
|
||||||
# Override vision config if LLM node has vision config,
|
if (
|
||||||
# cuz vision detail is related to the configuration from FileUpload feature.
|
(
|
||||||
content_item.detail = vision_detail
|
content_item.type == PromptMessageContentType.IMAGE
|
||||||
prompt_message_content.append(content_item)
|
and ModelFeature.VISION not in model_config.model_schema.features
|
||||||
elif isinstance(
|
)
|
||||||
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.DOCUMENT
|
||||||
|
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.VIDEO
|
||||||
|
and ModelFeature.VIDEO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.AUDIO
|
||||||
|
and ModelFeature.AUDIO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
):
|
):
|
||||||
prompt_message_content.append(content_item)
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
if len(prompt_message_content) > 1:
|
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||||
prompt_message.content = prompt_message_content
|
|
||||||
elif (
|
|
||||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
|
||||||
):
|
|
||||||
prompt_message.content = prompt_message_content[0].data
|
prompt_message.content = prompt_message_content[0].data
|
||||||
|
else:
|
||||||
|
prompt_message.content = prompt_message_content
|
||||||
|
if prompt_message.is_empty():
|
||||||
|
continue
|
||||||
filtered_prompt_messages.append(prompt_message)
|
filtered_prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
if not filtered_prompt_messages:
|
if len(filtered_prompt_messages) == 0:
|
||||||
raise NoPromptFoundError(
|
raise NoPromptFoundError(
|
||||||
"No prompt found in the LLM configuration. "
|
"No prompt found in the LLM configuration. "
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
"Please ensure a prompt is properly configured before proceeding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stop = model_config.stop
|
||||||
return filtered_prompt_messages, stop
|
return filtered_prompt_messages, stop
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
|
||||||
|
match role:
|
||||||
|
case PromptMessageRole.USER:
|
||||||
|
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.ASSISTANT:
|
||||||
|
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.SYSTEM:
|
||||||
|
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
raise NotImplementedError(f"Role {role} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _render_jinja2_message(
|
||||||
|
*,
|
||||||
|
template: str,
|
||||||
|
jinjia2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
):
|
||||||
|
if not template:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
jinjia2_inputs = {}
|
||||||
|
for jinja2_variable in jinjia2_variables:
|
||||||
|
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||||
|
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||||
|
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
||||||
|
language=CodeLanguage.JINJA2,
|
||||||
|
code=template,
|
||||||
|
inputs=jinjia2_inputs,
|
||||||
|
)
|
||||||
|
result_text = code_execute_resp["result"]
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_list_messages(
|
||||||
|
*,
|
||||||
|
messages: Sequence[LLMNodeChatModelMessage],
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
prompt_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=message.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
else:
|
||||||
|
# Get segment group from basic message
|
||||||
|
if context:
|
||||||
|
template = message.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template = message.text
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
|
|
||||||
|
# Process segments for images
|
||||||
|
file_contents = []
|
||||||
|
for segment in segment_group.value:
|
||||||
|
if isinstance(segment, ArrayFileSegment):
|
||||||
|
for file in segment.value:
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
file_contents.append(file_content)
|
||||||
|
if isinstance(segment, FileSegment):
|
||||||
|
file = segment.value
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
file_contents.append(file_content)
|
||||||
|
|
||||||
|
# Create message with text from all segments
|
||||||
|
plain_text = segment_group.text
|
||||||
|
if plain_text:
|
||||||
|
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
if file_contents:
|
||||||
|
# Create message with image contents
|
||||||
|
prompt_message = UserPromptMessage(content=file_contents)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_rest_token(
|
||||||
|
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||||
|
) -> int:
|
||||||
|
rest_tokens = 2000
|
||||||
|
|
||||||
|
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||||
|
if model_context_tokens:
|
||||||
|
model_instance = ModelInstance(
|
||||||
|
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
|
max_tokens = 0
|
||||||
|
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||||
|
if parameter_rule.name == "max_tokens" or (
|
||||||
|
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||||
|
):
|
||||||
|
max_tokens = (
|
||||||
|
model_config.parameters.get(parameter_rule.name)
|
||||||
|
or model_config.parameters.get(str(parameter_rule.use_template))
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||||
|
rest_tokens = max(rest_tokens, 0)
|
||||||
|
|
||||||
|
return rest_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_memory_chat_mode(
|
||||||
|
*,
|
||||||
|
memory: TokenBufferMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
memory_messages = []
|
||||||
|
# Get messages from memory for chat model
|
||||||
|
if memory and memory_config:
|
||||||
|
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||||
|
memory_messages = memory.get_history_prompt_messages(
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
)
|
||||||
|
return memory_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_memory_completion_mode(
|
||||||
|
*,
|
||||||
|
memory: TokenBufferMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
) -> str:
|
||||||
|
memory_text = ""
|
||||||
|
# Get history text from memory for completion model
|
||||||
|
if memory and memory_config:
|
||||||
|
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||||
|
if not memory_config.role_prefix:
|
||||||
|
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||||
|
memory_text = memory.get_history_prompt_text(
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
human_prefix=memory_config.role_prefix.user,
|
||||||
|
ai_prefix=memory_config.role_prefix.assistant,
|
||||||
|
)
|
||||||
|
return memory_text
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_completion_template(
|
||||||
|
*,
|
||||||
|
template: LLMNodeCompletionModelPromptTemplate,
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
"""Handle completion template processing outside of LLMNode class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: The completion model prompt template
|
||||||
|
context: Optional context string
|
||||||
|
jinja2_variables: Variables for jinja2 template rendering
|
||||||
|
variable_pool: Variable pool for template conversion
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sequence of prompt messages
|
||||||
|
"""
|
||||||
|
prompt_messages = []
|
||||||
|
if template.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=template.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if context:
|
||||||
|
template_text = template.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template_text = template.text
|
||||||
|
result_text = variable_pool.convert_template(template_text).text
|
||||||
|
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
return prompt_messages
|
||||||
|
|
|
@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
|
||||||
)
|
)
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
system_query=query,
|
user_query=query,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
files=files,
|
user_files=files,
|
||||||
vision_enabled=node_data.vision.enabled,
|
vision_enabled=node_data.vision.enabled,
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
|
|
61
api/poetry.lock
generated
61
api/poetry.lock
generated
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohappyeyeballs"
|
name = "aiohappyeyeballs"
|
||||||
|
@ -932,10 +932,6 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
|
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
|
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
|
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"},
|
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"},
|
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"},
|
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"},
|
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
|
{file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
|
||||||
{file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
|
{file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
|
{file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
|
||||||
|
@ -948,14 +944,8 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
|
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
|
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
|
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"},
|
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"},
|
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"},
|
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"},
|
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
|
{file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
|
||||||
{file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
|
{file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
|
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
|
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
|
{file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
|
||||||
|
@ -966,24 +956,8 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
|
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
|
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
|
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"},
|
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
|
{file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
|
||||||
{file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
|
{file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"},
|
|
||||||
{file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"},
|
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
|
{file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
|
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
|
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
|
||||||
|
@ -993,10 +967,6 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
|
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
|
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
|
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"},
|
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"},
|
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"},
|
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"},
|
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
|
{file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
|
||||||
{file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
|
{file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
|
{file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
|
||||||
|
@ -1008,10 +978,6 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
|
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
|
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
|
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"},
|
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"},
|
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"},
|
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"},
|
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
|
{file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
|
||||||
{file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
|
{file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
|
{file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
|
||||||
|
@ -1024,10 +990,6 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
|
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
|
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
|
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"},
|
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"},
|
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"},
|
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"},
|
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
|
{file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
|
||||||
{file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
|
{file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
|
{file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
|
||||||
|
@ -1040,10 +1002,6 @@ files = [
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
|
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
|
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
|
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"},
|
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"},
|
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"},
|
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"},
|
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
|
{file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
|
||||||
{file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
|
{file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
|
||||||
{file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
|
{file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
|
||||||
|
@ -2453,6 +2411,21 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["pytest (>=6)"]
|
test = ["pytest (>=6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "faker"
|
||||||
|
version = "32.1.0"
|
||||||
|
description = "Faker is a Python package that generates fake data for you."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
|
||||||
|
{file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
python-dateutil = ">=2.4"
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fal-client"
|
name = "fal-client"
|
||||||
version = "0.5.6"
|
version = "0.5.6"
|
||||||
|
@ -11078,4 +11051,4 @@ cffi = ["cffi (>=1.11)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "2ba4b464eebc26598f290fa94713acc44c588f902176e6efa80622911d40f0ac"
|
content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"
|
||||||
|
|
|
@ -267,6 +267,7 @@ weaviate-client = "~3.21.0"
|
||||||
optional = true
|
optional = true
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
coverage = "~7.2.4"
|
coverage = "~7.2.4"
|
||||||
|
faker = "~32.1.0"
|
||||||
pytest = "~8.3.2"
|
pytest = "~8.3.2"
|
||||||
pytest-benchmark = "~4.0.0"
|
pytest-benchmark = "~4.0.0"
|
||||||
pytest-env = "~1.1.3"
|
pytest-env = "~1.1.3"
|
||||||
|
|
|
@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
|
||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
||||||
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||||
|
|
|
@ -4,29 +4,21 @@ import pytest
|
||||||
|
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
|
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
|
||||||
|
|
||||||
|
|
||||||
def test_validate_credentials():
|
def test_validate_credentials():
|
||||||
model = AzureAIStudioRerankModel()
|
model = AzureRerankModel()
|
||||||
|
|
||||||
with pytest.raises(CredentialsValidateFailedError):
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
model.validate_credentials(
|
model.validate_credentials(
|
||||||
model="azure-ai-studio-rerank-v1",
|
model="azure-ai-studio-rerank-v1",
|
||||||
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
||||||
query="What is the capital of the United States?",
|
|
||||||
docs=[
|
|
||||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
|
||||||
"Census, Carson City had a population of 55,274.",
|
|
||||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
|
||||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
|
||||||
],
|
|
||||||
score_threshold=0.8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_invoke_model():
|
def test_invoke_model():
|
||||||
model = AzureAIStudioRerankModel()
|
model = AzureRerankModel()
|
||||||
|
|
||||||
result = model.invoke(
|
result = model.invoke(
|
||||||
model="azure-ai-studio-rerank-v1",
|
model="azure-ai-studio-rerank-v1",
|
||||||
|
|
|
@ -1,125 +1,484 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from configs import dify_config
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||||
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
|
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||||
|
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||||
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||||
from core.workflow.nodes.end import EndStreamParam
|
from core.workflow.nodes.end import EndStreamParam
|
||||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
ContextConfig,
|
||||||
|
LLMNodeChatModelMessage,
|
||||||
|
LLMNodeData,
|
||||||
|
ModelConfig,
|
||||||
|
VisionConfig,
|
||||||
|
VisionConfigOptions,
|
||||||
|
)
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
|
from models.provider import ProviderType
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
|
||||||
|
|
||||||
|
|
||||||
class TestLLMNode:
|
class MockTokenBufferMemory:
|
||||||
@pytest.fixture
|
def __init__(self, history_messages=None):
|
||||||
def llm_node(self):
|
self.history_messages = history_messages or []
|
||||||
data = LLMNodeData(
|
|
||||||
title="Test LLM",
|
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
|
||||||
prompt_template=[],
|
|
||||||
memory=None,
|
|
||||||
context=ContextConfig(enabled=False),
|
|
||||||
vision=VisionConfig(
|
|
||||||
enabled=True,
|
|
||||||
configs=VisionConfigOptions(
|
|
||||||
variable_selector=["sys", "files"],
|
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
variable_pool = VariablePool(
|
|
||||||
system_variables={},
|
|
||||||
user_inputs={},
|
|
||||||
)
|
|
||||||
node = LLMNode(
|
|
||||||
id="1",
|
|
||||||
config={
|
|
||||||
"id": "1",
|
|
||||||
"data": data.model_dump(),
|
|
||||||
},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="1",
|
|
||||||
app_id="1",
|
|
||||||
workflow_type=WorkflowType.WORKFLOW,
|
|
||||||
workflow_id="1",
|
|
||||||
graph_config={},
|
|
||||||
user_id="1",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph=Graph(
|
|
||||||
root_node_id="1",
|
|
||||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
|
||||||
answer_dependencies={},
|
|
||||||
answer_generate_route={},
|
|
||||||
),
|
|
||||||
end_stream_param=EndStreamParam(
|
|
||||||
end_dependencies={},
|
|
||||||
end_stream_variable_selector_mapping={},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
graph_runtime_state=GraphRuntimeState(
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
start_at=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def test_fetch_files_with_file_segment(self, llm_node):
|
def get_history_prompt_messages(
|
||||||
file = File(
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
if message_limit is not None:
|
||||||
|
return self.history_messages[-message_limit * 2 :]
|
||||||
|
return self.history_messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_node():
|
||||||
|
data = LLMNodeData(
|
||||||
|
title="Test LLM",
|
||||||
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
|
prompt_template=[],
|
||||||
|
memory=None,
|
||||||
|
context=ContextConfig(enabled=False),
|
||||||
|
vision=VisionConfig(
|
||||||
|
enabled=True,
|
||||||
|
configs=VisionConfigOptions(
|
||||||
|
variable_selector=["sys", "files"],
|
||||||
|
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
node = LLMNode(
|
||||||
|
id="1",
|
||||||
|
config={
|
||||||
|
"id": "1",
|
||||||
|
"data": data.model_dump(),
|
||||||
|
},
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config={},
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph=Graph(
|
||||||
|
root_node_id="1",
|
||||||
|
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||||
|
answer_dependencies={},
|
||||||
|
answer_generate_route={},
|
||||||
|
),
|
||||||
|
end_stream_param=EndStreamParam(
|
||||||
|
end_dependencies={},
|
||||||
|
end_stream_variable_selector_mapping={},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_config():
|
||||||
|
# Create actual provider and model type instances
|
||||||
|
model_provider_factory = ModelProviderFactory()
|
||||||
|
provider_instance = model_provider_factory.get_provider_instance("openai")
|
||||||
|
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
|
# Create a ProviderModelBundle
|
||||||
|
provider_model_bundle = ProviderModelBundle(
|
||||||
|
configuration=ProviderConfiguration(
|
||||||
|
tenant_id="1",
|
||||||
|
provider=provider_instance.get_provider_schema(),
|
||||||
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
|
using_provider_type=ProviderType.CUSTOM,
|
||||||
|
system_configuration=SystemConfiguration(enabled=False),
|
||||||
|
custom_configuration=CustomConfiguration(provider=None),
|
||||||
|
model_settings=[],
|
||||||
|
),
|
||||||
|
provider_instance=provider_instance,
|
||||||
|
model_type_instance=model_type_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and return a ModelConfigWithCredentialsEntity
|
||||||
|
return ModelConfigWithCredentialsEntity(
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
model_schema=AIModelEntity(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
label=I18nObject(en_US="GPT-3.5 Turbo"),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={},
|
||||||
|
),
|
||||||
|
mode="chat",
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_file_segment(llm_node):
|
||||||
|
file = File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="1",
|
||||||
|
)
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == [file]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_file_segment(llm_node):
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
id="1",
|
id="1",
|
||||||
tenant_id="test",
|
tenant_id="test",
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
filename="test.jpg",
|
filename="test1.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
|
),
|
||||||
|
File(
|
||||||
|
id="2",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test2.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == files
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_none_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_any_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_non_existent_variable(llm_node):
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||||
|
prompt_template = []
|
||||||
|
llm_node.node_data.prompt_template = prompt_template
|
||||||
|
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
)
|
)
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
]
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
fake_query = faker.sentence()
|
||||||
assert result == [file]
|
|
||||||
|
|
||||||
def test_fetch_files_with_array_file_segment(self, llm_node):
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
files = [
|
user_query=fake_query,
|
||||||
File(
|
user_files=files,
|
||||||
id="1",
|
context=None,
|
||||||
tenant_id="test",
|
memory=None,
|
||||||
type=FileType.IMAGE,
|
model_config=model_config,
|
||||||
filename="test1.jpg",
|
prompt_template=prompt_template,
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
memory_config=None,
|
||||||
related_id="1",
|
vision_enabled=False,
|
||||||
),
|
vision_detail=fake_vision_detail,
|
||||||
File(
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
id="2",
|
jinja2_variables=[],
|
||||||
tenant_id="test",
|
)
|
||||||
type=FileType.IMAGE,
|
|
||||||
filename="test2.jpg",
|
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
|
||||||
related_id="2",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
assert prompt_messages == [UserPromptMessage(content=fake_query)]
|
||||||
assert result == files
|
|
||||||
|
|
||||||
def test_fetch_files_with_none_segment(self, llm_node):
|
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||||
assert result == []
|
# Setup dify config
|
||||||
|
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||||
|
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||||
|
|
||||||
def test_fetch_files_with_array_any_segment(self, llm_node):
|
# Generate fake values for prompt template
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
fake_assistant_prompt = faker.sentence()
|
||||||
|
fake_query = faker.sentence()
|
||||||
|
fake_context = faker.sentence()
|
||||||
|
fake_window_size = faker.random_int(min=1, max=3)
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
# Setup mock memory with history messages
|
||||||
assert result == []
|
mock_history = [
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
]
|
||||||
|
|
||||||
def test_fetch_files_with_non_existent_variable(self, llm_node):
|
# Setup memory configuration
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
memory_config = MemoryConfig(
|
||||||
assert result == []
|
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||||
|
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
|
||||||
|
query_prompt_template=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = MockTokenBufferMemory(history_messages=mock_history)
|
||||||
|
|
||||||
|
# Test scenarios covering different file input combinations
|
||||||
|
test_scenarios = [
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="No files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
features=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=None,
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(content=fake_query),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="User files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[
|
||||||
|
File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data=fake_query),
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File without vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File with video file and vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.VIDEO,
|
||||||
|
filename="test1.mp4",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
extension="mp4",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in test_scenarios:
|
||||||
|
model_config.model_schema.features = scenario.features
|
||||||
|
|
||||||
|
for k, v in scenario.file_variables.items():
|
||||||
|
selector = k.split(".")
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(selector, v)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
|
user_query=scenario.user_query,
|
||||||
|
user_files=scenario.user_files,
|
||||||
|
context=fake_context,
|
||||||
|
memory=memory,
|
||||||
|
model_config=model_config,
|
||||||
|
prompt_template=scenario.prompt_template,
|
||||||
|
memory_config=memory_config,
|
||||||
|
vision_enabled=scenario.vision_enabled,
|
||||||
|
vision_detail=scenario.vision_detail,
|
||||||
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
|
||||||
|
assert (
|
||||||
|
prompt_messages == scenario.expected_messages
|
||||||
|
), f"Message content mismatch in scenario: {scenario.description}"
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessage
|
||||||
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
|
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNodeTestScenario(BaseModel):
|
||||||
|
"""Test scenario for LLM node testing."""
|
||||||
|
|
||||||
|
description: str = Field(..., description="Description of the test scenario")
|
||||||
|
user_query: str = Field(..., description="User query input")
|
||||||
|
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||||
|
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||||
|
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||||
|
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||||
|
window_size: int = Field(..., description="Window size for memory")
|
||||||
|
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
|
||||||
|
file_variables: Mapping[str, File | Sequence[File]] = Field(
|
||||||
|
default_factory=dict, description="List of file variables"
|
||||||
|
)
|
||||||
|
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
|
|
@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
|
||||||
}
|
}
|
||||||
|
|
||||||
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
||||||
if (fileMimetype)
|
|
||||||
return mime.getExtension(fileMimetype) || ''
|
|
||||||
|
|
||||||
if (isRemote)
|
|
||||||
return ''
|
|
||||||
|
|
||||||
if (fileName) {
|
if (fileName) {
|
||||||
const fileNamePair = fileName.split('.')
|
const fileNamePair = fileName.split('.')
|
||||||
const fileNamePairLength = fileNamePair.length
|
const fileNamePairLength = fileNamePair.length
|
||||||
|
@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
|
||||||
return fileNamePair[fileNamePairLength - 1]
|
return fileNamePair[fileNamePairLength - 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (fileMimetype)
|
||||||
|
return mime.getExtension(fileMimetype) || ''
|
||||||
|
|
||||||
|
if (isRemote)
|
||||||
|
return ''
|
||||||
|
|
||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
|
||||||
onEditionTypeChange={onEditionTypeChange}
|
onEditionTypeChange={onEditionTypeChange}
|
||||||
varList={varList}
|
varList={varList}
|
||||||
handleAddVariable={handleAddVariable}
|
handleAddVariable={handleAddVariable}
|
||||||
|
isSupportFileVar
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
handleStop,
|
handleStop,
|
||||||
varInputs,
|
varInputs,
|
||||||
runResult,
|
runResult,
|
||||||
|
filterJinjia2InputVar,
|
||||||
} = useConfig(id, data)
|
} = useConfig(id, data)
|
||||||
|
|
||||||
const model = inputs.model
|
const model = inputs.model
|
||||||
|
@ -194,7 +195,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
list={inputs.prompt_config?.jinja2_variables || []}
|
list={inputs.prompt_config?.jinja2_variables || []}
|
||||||
onChange={handleVarListChange}
|
onChange={handleVarListChange}
|
||||||
onVarNameChange={handleVarNameChange}
|
onVarNameChange={handleVarNameChange}
|
||||||
filterVar={filterVar}
|
filterVar={filterJinjia2InputVar}
|
||||||
/>
|
/>
|
||||||
</Field>
|
</Field>
|
||||||
)}
|
)}
|
||||||
|
@ -233,6 +234,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
hasSetBlockStatus={hasSetBlockStatus}
|
hasSetBlockStatus={hasSetBlockStatus}
|
||||||
nodesOutputVars={availableVars}
|
nodesOutputVars={availableVars}
|
||||||
availableNodes={availableNodesWithParent}
|
availableNodes={availableNodesWithParent}
|
||||||
|
isSupportFileVar
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
||||||
|
|
|
@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
||||||
}, [inputs, setInputs])
|
}, [inputs, setInputs])
|
||||||
|
|
||||||
const filterInputVar = useCallback((varPayload: Var) => {
|
const filterInputVar = useCallback((varPayload: Var) => {
|
||||||
|
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const filterJinjia2InputVar = useCallback((varPayload: Var) => {
|
||||||
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
|
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
|
||||||
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
||||||
handleRun,
|
handleRun,
|
||||||
handleStop,
|
handleStop,
|
||||||
runResult,
|
runResult,
|
||||||
|
filterJinjia2InputVar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user