mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)
This commit is contained in:
parent
5e42e90abc
commit
4f5f27cf2b
|
@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
|
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
args: dict,
|
args: dict,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
) -> Union[dict, Generator[dict, None, None]]:
|
):
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
|
||||||
# get conversation
|
# get conversation
|
||||||
conversation = None
|
conversation = None
|
||||||
if args.get('conversation_id'):
|
conversation_id = args.get('conversation_id')
|
||||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
if conversation_id:
|
||||||
|
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args['files'] if args.get('files') else []
|
files = args['files'] if args.get('files') else []
|
||||||
|
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user: Account,
|
user: Account,
|
||||||
args: dict,
|
args: dict,
|
||||||
stream: bool = True) \
|
stream: bool = True):
|
||||||
-> Union[dict, Generator[dict, None, None]]:
|
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
|
||||||
# get conversation
|
# get conversation
|
||||||
conversation = None
|
conversation = None
|
||||||
if args.get('conversation_id'):
|
conversation_id = args.get('conversation_id')
|
||||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
if conversation_id:
|
||||||
|
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||||
|
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
conversation: Conversation | None = None,
|
conversation: Conversation | None = None,
|
||||||
stream: bool = True) \
|
stream: bool = True):
|
||||||
-> Union[dict, Generator[dict, None, None]]:
|
|
||||||
is_first_conversation = False
|
is_first_conversation = False
|
||||||
if not conversation:
|
if not conversation:
|
||||||
is_first_conversation = True
|
is_first_conversation = True
|
||||||
|
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = {
|
system_inputs = {
|
||||||
SystemVariable.QUERY: query,
|
SystemVariableKey.QUERY: query,
|
||||||
SystemVariable.FILES: files,
|
SystemVariableKey.FILES: files,
|
||||||
SystemVariable.CONVERSATION_ID: conversation_id,
|
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||||
SystemVariable.USER_ID: user_id,
|
SystemVariableKey.USER_ID: user_id,
|
||||||
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||||
}
|
}
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
|
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except (ValueError, InvokeError) as e:
|
||||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
if os.environ.get("DEBUG", "false").lower() == 'true':
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
|
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
_workflow: Workflow
|
_workflow: Workflow
|
||||||
_user: Union[Account, EndUser]
|
_user: Union[Account, EndUser]
|
||||||
# Deprecated
|
# Deprecated
|
||||||
_workflow_system_variables: dict[SystemVariable, Any]
|
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||||
_iteration_nested_relations: dict[str, list[str]]
|
_iteration_nested_relations: dict[str, list[str]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
self._message = message
|
self._message = message
|
||||||
# Deprecated
|
# Deprecated
|
||||||
self._workflow_system_variables = {
|
self._workflow_system_variables = {
|
||||||
SystemVariable.QUERY: message.query,
|
SystemVariableKey.QUERY: message.query,
|
||||||
SystemVariable.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||||
SystemVariable.USER_ID: user_id,
|
SystemVariableKey.USER_ID: user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._task_state = AdvancedChatTaskState(
|
self._task_state = AdvancedChatTaskState(
|
||||||
|
|
|
@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -67,8 +67,8 @@ class WorkflowAppRunner:
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = {
|
system_inputs = {
|
||||||
SystemVariable.FILES: files,
|
SystemVariableKey.FILES: files,
|
||||||
SystemVariable.USER_ID: user_id,
|
SystemVariableKey.USER_ID: user_id,
|
||||||
}
|
}
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
|
|
|
@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||||
_user: Union[Account, EndUser]
|
_user: Union[Account, EndUser]
|
||||||
_task_state: WorkflowTaskState
|
_task_state: WorkflowTaskState
|
||||||
_application_generate_entity: WorkflowAppGenerateEntity
|
_application_generate_entity: WorkflowAppGenerateEntity
|
||||||
_workflow_system_variables: dict[SystemVariable, Any]
|
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||||
_iteration_nested_relations: dict[str, list[str]]
|
_iteration_nested_relations: dict[str, list[str]]
|
||||||
|
|
||||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
|
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||||
|
|
||||||
self._workflow = workflow
|
self._workflow = workflow
|
||||||
self._workflow_system_variables = {
|
self._workflow_system_variables = {
|
||||||
SystemVariable.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
SystemVariable.USER_ID: user_id
|
SystemVariableKey.USER_ID: user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState(
|
self._task_state = WorkflowTaskState(
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
|
||||||
_workflow: Workflow
|
_workflow: Workflow
|
||||||
_user: Union[Account, EndUser]
|
_user: Union[Account, EndUser]
|
||||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||||
_workflow_system_variables: dict[SystemVariable, Any]
|
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||||
|
|
|
@ -6,20 +6,20 @@ from typing_extensions import deprecated
|
||||||
|
|
||||||
from core.app.segments import Segment, Variable, factory
|
from core.app.segments import Segment, Variable, factory
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
|
||||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||||
|
|
||||||
|
|
||||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||||
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
|
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||||
|
|
||||||
|
|
||||||
class VariablePool:
|
class VariablePool:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
system_variables: Mapping[SystemVariable, Any],
|
system_variables: Mapping[SystemVariableKey, Any],
|
||||||
user_inputs: Mapping[str, Any],
|
user_inputs: Mapping[str, Any],
|
||||||
environment_variables: Sequence[Variable],
|
environment_variables: Sequence[Variable],
|
||||||
conversation_variables: Sequence[Variable] | None = None,
|
conversation_variables: Sequence[Variable] | None = None,
|
||||||
|
@ -68,7 +68,7 @@ class VariablePool:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
if len(selector) < 2:
|
if len(selector) < 2:
|
||||||
raise ValueError('Invalid selector')
|
raise ValueError("Invalid selector")
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
|
@ -95,13 +95,13 @@ class VariablePool:
|
||||||
ValueError: If the selector is invalid.
|
ValueError: If the selector is invalid.
|
||||||
"""
|
"""
|
||||||
if len(selector) < 2:
|
if len(selector) < 2:
|
||||||
raise ValueError('Invalid selector')
|
raise ValueError("Invalid selector")
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@deprecated('This method is deprecated, use `get` instead.')
|
@deprecated("This method is deprecated, use `get` instead.")
|
||||||
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
||||||
"""
|
"""
|
||||||
Retrieves the value from the variable pool based on the given selector.
|
Retrieves the value from the variable pool based on the given selector.
|
||||||
|
@ -116,7 +116,7 @@ class VariablePool:
|
||||||
ValueError: If the selector is invalid.
|
ValueError: If the selector is invalid.
|
||||||
"""
|
"""
|
||||||
if len(selector) < 2:
|
if len(selector) < 2:
|
||||||
raise ValueError('Invalid selector')
|
raise ValueError("Invalid selector")
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||||
return value.to_object() if value else None
|
return value.to_object() if value else None
|
||||||
|
|
|
@ -1,25 +1,13 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class SystemVariable(str, Enum):
|
class SystemVariableKey(str, Enum):
|
||||||
"""
|
"""
|
||||||
System Variables.
|
System Variables.
|
||||||
"""
|
"""
|
||||||
QUERY = 'query'
|
|
||||||
FILES = 'files'
|
|
||||||
CONVERSATION_ID = 'conversation_id'
|
|
||||||
USER_ID = 'user_id'
|
|
||||||
DIALOGUE_COUNT = 'dialogue_count'
|
|
||||||
|
|
||||||
@classmethod
|
QUERY = "query"
|
||||||
def value_of(cls, value: str):
|
FILES = "files"
|
||||||
"""
|
CONVERSATION_ID = "conversation_id"
|
||||||
Get value of given system variable.
|
USER_ID = "user_id"
|
||||||
|
DIALOGUE_COUNT = "dialogue_count"
|
||||||
:param value: system variable value
|
|
||||||
:return: system variable
|
|
||||||
"""
|
|
||||||
for system_variable in cls:
|
|
||||||
if system_variable.value == value:
|
|
||||||
return system_variable
|
|
||||||
raise ValueError(f'invalid system variable value {value}')
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.llm.entities import (
|
from core.workflow.nodes.llm.entities import (
|
||||||
LLMNodeChatModelMessage,
|
LLMNodeChatModelMessage,
|
||||||
|
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
|
||||||
# fetch prompt messages
|
# fetch prompt messages
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
node_data=node_data,
|
node_data=node_data,
|
||||||
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
|
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
|
||||||
if node_data.memory else None,
|
if node_data.memory else None,
|
||||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
@ -335,7 +335,7 @@ class LLMNode(BaseNode):
|
||||||
if not node_data.vision.enabled:
|
if not node_data.vision.enabled:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
|
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -500,7 +500,7 @@ class LLMNode(BaseNode):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# get conversation id
|
# get conversation id
|
||||||
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
|
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
|
||||||
if conversation_id is None:
|
if conversation_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -672,10 +672,10 @@ class LLMNode(BaseNode):
|
||||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||||
|
|
||||||
if node_data.vision.enabled:
|
if node_data.vision.enabled:
|
||||||
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
|
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
|
||||||
|
|
||||||
if node_data.memory:
|
if node_data.memory:
|
||||||
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
|
||||||
|
|
||||||
if node_data.prompt_config:
|
if node_data.prompt_config:
|
||||||
enable_jinja = False
|
enable_jinja = False
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.start.entities import StartNodeData
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
@ -17,16 +17,16 @@ class StartNode(BaseNode):
|
||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# Get cleaned inputs
|
node_inputs = dict(variable_pool.user_inputs)
|
||||||
cleaned_inputs = dict(variable_pool.user_inputs)
|
system_inputs = variable_pool.system_variables
|
||||||
|
|
||||||
for var in variable_pool.system_variables:
|
for var in system_inputs:
|
||||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
|
||||||
|
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=cleaned_inputs,
|
inputs=node_inputs,
|
||||||
outputs=cleaned_inputs
|
outputs=node_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
|
@ -141,7 +141,7 @@ class ToolNode(BaseNode):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
|
||||||
assert isinstance(variable, ArrayAnyVariable)
|
assert isinstance(variable, ArrayAnyVariable)
|
||||||
return list(variable.value) if variable else []
|
return list(variable.value) if variable else []
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers import ModelProviderFactory
|
from core.model_runtime.model_providers import ModelProviderFactory
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
pool.add(['abc', 'output'], 'sunny')
|
pool.add(['abc', 'output'], 'sunny')
|
||||||
|
|
||||||
|
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
pool.add(['abc', 'output'], 'sunny')
|
pool.add(['abc', 'output'], 'sunny')
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
|
||||||
result = node.run(pool)
|
result = node.run(pool)
|
||||||
|
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
|
||||||
result = node.run(pool)
|
result = node.run(pool)
|
||||||
|
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
|
||||||
result = node.run(pool)
|
result = node.run(pool)
|
||||||
|
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
|
||||||
result = node.run(pool)
|
result = node.run(pool)
|
||||||
|
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
|
||||||
result = node.run(pool)
|
result = node.run(pool)
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
from core.app.segments import SecretVariable, StringSegment, parser
|
from core.app.segments import SecretVariable, StringSegment, parser
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
|
||||||
|
|
||||||
def test_segment_group_to_text():
|
def test_segment_group_to_text():
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables={
|
||||||
SystemVariable('user_id'): 'fake-user-id',
|
SystemVariableKey('user_id'): 'fake-user-id',
|
||||||
},
|
},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[
|
environment_variables=[
|
||||||
|
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
|
||||||
def test_convert_variable_to_segment_group():
|
def test_convert_variable_to_segment_group():
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={
|
system_variables={
|
||||||
SystemVariable('user_id'): 'fake-user-id',
|
SystemVariableKey('user_id'): 'fake-user-id',
|
||||||
},
|
},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
|
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -29,8 +29,8 @@ def test_execute_answer():
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
pool.add(['start', 'weather'], 'sunny')
|
pool.add(['start', 'weather'], 'sunny')
|
||||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||||
|
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||||
|
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
|
||||||
|
|
||||||
# construct variable pool
|
# construct variable pool
|
||||||
pool = VariablePool(system_variables={
|
pool = VariablePool(system_variables={
|
||||||
SystemVariable.FILES: [],
|
SystemVariableKey.FILES: [],
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariableKey.USER_ID: 'aaa'
|
||||||
}, user_inputs={}, environment_variables=[])
|
}, user_inputs={}, environment_variables=[])
|
||||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||||
|
|
|
@ -4,7 +4,7 @@ from uuid import uuid4
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.segments import ArrayStringVariable, StringVariable
|
from core.app.segments import ArrayStringVariable, StringVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariable
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
|
||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
|
@ -93,7 +93,7 @@ def test_append_variable_to_array():
|
||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
|
@ -137,7 +137,7 @@ def test_clear_array():
|
||||||
)
|
)
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
environment_variables=[],
|
environment_variables=[],
|
||||||
conversation_variables=[conversation_variable],
|
conversation_variables=[conversation_variable],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user