mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)
This commit is contained in:
parent
5e42e90abc
commit
4f5f27cf2b
|
@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
|
|||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
|
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
|
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
stream: bool = True):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
|
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
stream: bool = True):
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
|
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation_id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
|
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
|
|
|
@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from events.message_event import message_was_created
|
||||
|
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
|
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
|
|
|
@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
|
@ -67,8 +67,8 @@ class WorkflowAppRunner:
|
|||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
|
|
|
@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
|||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
|
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
|
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
|
||||
self._workflow = workflow
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState(
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Any, Union
|
|||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
|
|||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
|
|
@ -6,20 +6,20 @@ from typing_extensions import deprecated
|
|||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
system_variables: Mapping[SystemVariableKey, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
|
@ -68,7 +68,7 @@ class VariablePool:
|
|||
None
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if value is None:
|
||||
return
|
||||
|
@ -95,13 +95,13 @@ class VariablePool:
|
|||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
@deprecated('This method is deprecated, use `get` instead.')
|
||||
@deprecated("This method is deprecated, use `get` instead.")
|
||||
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
|
@ -116,7 +116,7 @@ class VariablePool:
|
|||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
return value.to_object() if value else None
|
||||
|
|
|
@ -1,25 +1,13 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class SystemVariable(str, Enum):
|
||||
class SystemVariableKey(str, Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION_ID = 'conversation_id'
|
||||
USER_ID = 'user_id'
|
||||
DIALOGUE_COUNT = 'dialogue_count'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given system variable.
|
||||
|
||||
:param value: system variable value
|
||||
:return: system variable
|
||||
"""
|
||||
for system_variable in cls:
|
||||
if system_variable.value == value:
|
||||
return system_variable
|
||||
raise ValueError(f'invalid system variable value {value}')
|
||||
QUERY = "query"
|
||||
FILES = "files"
|
||||
CONVERSATION_ID = "conversation_id"
|
||||
USER_ID = "user_id"
|
||||
DIALOGUE_COUNT = "dialogue_count"
|
||||
|
|
|
@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
|
|||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
|
@ -94,7 +94,7 @@ class LLMNode(BaseNode):
|
|||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
|
||||
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
|
||||
if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
|
@ -335,7 +335,7 @@ class LLMNode(BaseNode):
|
|||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
|
||||
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
|
@ -500,7 +500,7 @@ class LLMNode(BaseNode):
|
|||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
|
@ -672,10 +672,10 @@ class LLMNode(BaseNode):
|
|||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
|
||||
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
|
||||
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
@ -17,16 +17,16 @@ class StartNode(BaseNode):
|
|||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = dict(variable_pool.user_inputs)
|
||||
node_inputs = dict(variable_pool.user_inputs)
|
||||
system_inputs = variable_pool.system_variables
|
||||
|
||||
for var in variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=cleaned_inputs,
|
||||
outputs=cleaned_inputs
|
||||
inputs=node_inputs,
|
||||
outputs=node_inputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
@ -141,7 +141,7 @@ class ToolNode(BaseNode):
|
|||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
||||
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers import ModelProviderFactory
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['abc', 'output'], 'sunny')
|
||||
|
||||
|
@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather today?',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather today?',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['abc', 'output'], 'sunny')
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.QUERY: 'what\'s the weather in SF',
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
|
||||
result = node.run(pool)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
|
||||
def test_segment_group_to_text():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
|
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
|
|||
def test_convert_variable_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
|
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from extensions.ext_database import db
|
||||
|
@ -29,8 +29,8 @@ def test_execute_answer():
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'weather'], 'sunny')
|
||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||
|
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -119,8 +119,8 @@ def test_execute_if_else_result_true():
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
|
@ -182,8 +182,8 @@ def test_execute_if_else_result_false():
|
|||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
|
|
@ -4,7 +4,7 @@ from uuid import uuid4
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
|
@ -42,7 +42,7 @@ def test_overwrite_string_variable():
|
|||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
|
@ -93,7 +93,7 @@ def test_append_variable_to_array():
|
|||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
|
@ -137,7 +137,7 @@ def test_clear_array():
|
|||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
|
|
Loading…
Reference in New Issue
Block a user