refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)

This commit is contained in:
-LAN- 2024-08-20 17:52:06 +08:00 committed by GitHub
parent 5e42e90abc
commit 4f5f27cf2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 106 additions and 118 deletions

View File

@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
):
"""
Generate App response.
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files
files = args['files'] if args.get('files') else []
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
"""
Generate App response.
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(

View File

@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -67,8 +67,8 @@ class WorkflowAppRunner:
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,

View File

@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -6,20 +6,20 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
@ -68,7 +68,7 @@ class VariablePool:
None
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
if value is None:
return
@ -95,13 +95,13 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
@deprecated('This method is deprecated, use `get` instead.')
@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.
@ -116,7 +116,7 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None

View File

@ -1,25 +1,13 @@
from enum import Enum
class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"

View File

@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@ -335,7 +335,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files:
return []
@ -500,7 +500,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
@ -672,10 +672,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False

View File

@ -1,7 +1,7 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -17,16 +17,16 @@ class StartNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(variable_pool.user_inputs)
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables
for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
inputs=node_inputs,
outputs=node_inputs
)
@classmethod

View File

@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -141,7 +141,7 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []

View File

@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')

View File

@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
result = node.run(pool)

View File

@ -1,13 +1,13 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[],

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db
@ -29,8 +29,8 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')

View File

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -93,7 +93,7 @@ def test_append_variable_to_array():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -137,7 +137,7 @@ def test_clear_array():
)
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],