refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)

This commit is contained in:
-LAN- 2024-08-20 17:52:06 +08:00 committed by GitHub
parent 5e42e90abc
commit 4f5f27cf2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 106 additions and 118 deletions

View File

@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]: ):
""" """
Generate App response. Generate App response.
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): conversation_id = args.get('conversation_id')
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args['files'] if args.get('files') else []
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str, node_id: str,
user: Account, user: Account,
args: dict, args: dict,
stream: bool = True) \ stream: bool = True):
-> Union[dict, Generator[dict, None, None]]:
""" """
Generate App response. Generate App response.
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): conversation_id = args.get('conversation_id')
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config # convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config( app_config = AdvancedChatAppConfigManager.get_app_config(
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None, conversation: Conversation | None = None,
stream: bool = True) \ stream: bool = True):
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False is_first_conversation = False
if not conversation: if not conversation:
is_first_conversation = True is_first_conversation = True
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create a variable pool. # Create a variable pool.
system_inputs = { system_inputs = {
SystemVariable.QUERY: query, SystemVariableKey.QUERY: query,
SystemVariable.FILES: files, SystemVariableKey.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id, SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id, SystemVariableKey.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
} }
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created from events.message_event import message_was_created
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow _workflow: Workflow
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
# Deprecated # Deprecated
_workflow_system_variables: dict[SystemVariable, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]] _iteration_nested_relations: dict[str, list[str]]
def __init__( def __init__(
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message = message self._message = message
# Deprecated # Deprecated
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariable.QUERY: message.query, SystemVariableKey.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id, SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id, SystemVariableKey.USER_ID: user_id,
} }
self._task_state = AdvancedChatTaskState( self._task_state = AdvancedChatTaskState(

View File

@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db from extensions.ext_database import db
@ -67,8 +67,8 @@ class WorkflowAppRunner:
# Create a variable pool. # Create a variable pool.
system_inputs = { system_inputs = {
SystemVariable.FILES: files, SystemVariableKey.FILES: files,
SystemVariable.USER_ID: user_id, SystemVariableKey.USER_ID: user_id,
} }
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,

View File

@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]] _iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id SystemVariableKey.USER_ID: user_id
} }
self._task_state = WorkflowTaskState( self._task_state = WorkflowTaskState(

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from models.account import Account from models.account import Account
from models.model import EndUser from models.model import EndUser
from models.workflow import Workflow from models.workflow import Workflow
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow _workflow: Workflow
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any] _workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -6,20 +6,20 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar] VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys' SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = 'env' ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = 'conversation' CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool: class VariablePool:
def __init__( def __init__(
self, self,
system_variables: Mapping[SystemVariable, Any], system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any], user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None, conversation_variables: Sequence[Variable] | None = None,
@ -68,7 +68,7 @@ class VariablePool:
None None
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
if value is None: if value is None:
return return
@ -95,13 +95,13 @@ class VariablePool:
ValueError: If the selector is invalid. ValueError: If the selector is invalid.
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key) value = self._variable_dictionary[selector[0]].get(hash_key)
return value return value
@deprecated('This method is deprecated, use `get` instead.') @deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None: def get_any(self, selector: Sequence[str], /) -> Any | None:
""" """
Retrieves the value from the variable pool based on the given selector. Retrieves the value from the variable pool based on the given selector.
@ -116,7 +116,7 @@ class VariablePool:
ValueError: If the selector is invalid. ValueError: If the selector is invalid.
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key) value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None return value.to_object() if value else None

View File

@ -1,25 +1,13 @@
from enum import Enum from enum import Enum
class SystemVariable(str, Enum): class SystemVariableKey(str, Enum):
""" """
System Variables. System Variables.
""" """
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod QUERY = "query"
def value_of(cls, value: str): FILES = "files"
""" CONVERSATION_ID = "conversation_id"
Get value of given system variable. USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

View File

@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import ( from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage, LLMNodeChatModelMessage,
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
# fetch prompt messages # fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data, node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None, if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs, inputs=inputs,
@ -335,7 +335,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled: if not node_data.vision.enabled:
return [] return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files: if not files:
return [] return []
@ -500,7 +500,7 @@ class LLMNode(BaseNode):
return None return None
# get conversation id # get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None: if conversation_id is None:
return None return None
@ -672,10 +672,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled: if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory: if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config: if node_data.prompt_config:
enable_jinja = False enable_jinja = False

View File

@ -1,7 +1,7 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -17,16 +17,16 @@ class StartNode(BaseNode):
:param variable_pool: variable pool :param variable_pool: variable pool
:return: :return:
""" """
# Get cleaned inputs node_inputs = dict(variable_pool.user_inputs)
cleaned_inputs = dict(variable_pool.user_inputs) system_inputs = variable_pool.system_variables
for var in variable_pool.system_variables: for var in system_inputs:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs, inputs=node_inputs,
outputs=cleaned_inputs outputs=node_inputs
) )
@classmethod @classmethod

View File

@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -141,7 +141,7 @@ class ToolNode(BaseNode):
return result return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value]) variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable) assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else [] return list(variable.value) if variable else []

View File

@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?', SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny') pool.add(['abc', 'output'], 'sunny')
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?', SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny') pool.add(['abc', 'output'], 'sunny')

View File

@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db from extensions.ext_database import db
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF', SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
result = node.run(pool) result = node.run(pool)
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF', SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
result = node.run(pool) result = node.run(pool)
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF', SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
result = node.run(pool) result = node.run(pool)
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF', SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
result = node.run(pool) result = node.run(pool)
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF', SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
result = node.run(pool) result = node.run(pool)

View File

@ -1,13 +1,13 @@
from core.app.segments import SecretVariable, StringSegment, parser from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter from core.helper import encrypter
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
def test_segment_group_to_text(): def test_segment_group_to_text():
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={ system_variables={
SystemVariable('user_id'): 'fake-user-id', SystemVariableKey('user_id'): 'fake-user-id',
}, },
user_inputs={}, user_inputs={},
environment_variables=[ environment_variables=[
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group(): def test_convert_variable_to_segment_group():
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={ system_variables={
SystemVariable('user_id'): 'fake-user-id', SystemVariableKey('user_id'): 'fake-user-id',
}, },
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db from extensions.ext_database import db
@ -29,8 +29,8 @@ def test_execute_answer():
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny') pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.') pool.add(['llm', 'text'], 'You are a helpful AI.')

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db from extensions.ext_database import db
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def']) pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def']) pool.add(['start', 'array_not_contains'], ['ac', 'def'])
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
# construct variable pool # construct variable pool
pool = VariablePool(system_variables={ pool = VariablePool(system_variables={
SystemVariable.FILES: [], SystemVariableKey.FILES: [],
SystemVariable.USER_ID: 'aaa' SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[]) }, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def']) pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def']) pool.add(['start', 'array_not_contains'], ['ab', 'def'])

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
) )
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
@ -93,7 +93,7 @@ def test_append_variable_to_array():
) )
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
@ -137,7 +137,7 @@ def test_clear_array():
) )
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],