feat: add if elif (#6094)

This commit is contained in:
Joe 2024-07-10 18:22:51 +08:00 committed by GitHub
parent ebba124c5c
commit 5a3e09518c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 157 additions and 81 deletions

View File

@ -5,10 +5,6 @@ from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class IfElseNodeData(BaseNodeData):
"""
Answer Node Data.
"""
class Condition(BaseModel):
"""
Condition entity
@ -22,5 +18,21 @@ class IfElseNodeData(BaseNodeData):
]
value: Optional[str] = None
logical_operator: Literal["and", "or"] = "and"
class IfElseNodeData(BaseNodeData):
"""
Answer Node Data.
"""
class Case(BaseModel):
"""
Case entity representing a single logical condition group
"""
case_id: str
logical_operator: Literal["and", "or"]
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = None
cases: Optional[list[Case]] = None

View File

@ -4,7 +4,8 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
@ -29,68 +30,46 @@ class IfElseNode(BaseNode):
"condition_results": []
}
try:
logical_operator = node_data.logical_operator
input_conditions = []
for condition in node_data.conditions:
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
final_result = False
selected_case_id = None
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
process_datas["condition_results"].append(
{
"group": case.model_dump(),
"results": group_result,
"final_result": final_result,
}
)
expected_value = condition.value
# Break if a case passes (logical short-circuit)
if final_result:
selected_case_id = case.case_id # Capture the ID of the passing case
break
input_conditions.append({
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": condition.comparison_operator
})
else:
# Fallback to old structure if cases are not defined
input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
process_datas["condition_results"].append(
{
"group": "default",
"results": group_result,
"final_result": final_result
}
)
node_inputs["conditions"] = input_conditions
for input_condition in input_conditions:
actual_value = input_condition["actual_value"]
expected_value = input_condition["expected_value"]
comparison_operator = input_condition["comparison_operator"]
if comparison_operator == "contains":
compare_result = self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
compare_result = self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
compare_result = self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
compare_result = self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
compare_result = self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
compare_result = self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
compare_result = self._assert_empty(actual_value)
elif comparison_operator == "not empty":
compare_result = self._assert_not_empty(actual_value)
elif comparison_operator == "=":
compare_result = self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
compare_result = self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
compare_result = self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
compare_result = self._assert_null(actual_value)
elif comparison_operator == "not null":
compare_result = self._assert_not_null(actual_value)
else:
continue
process_datas["condition_results"].append({
**input_condition,
"result": compare_result
})
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -99,21 +78,106 @@ class IfElseNode(BaseNode):
error=str(e)
)
if logical_operator == "and":
compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
else:
compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]
outputs = {
"result": final_result
}
if node_data.cases:
outputs["selected_case_id"] = selected_case_id
return NodeRunResult(
data = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_datas,
edge_source_handle="false" if not compare_result else "true",
outputs={
"result": compare_result
edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default'
outputs=outputs
)
return data
def evaluate_condition(
self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
) -> bool:
"""
Evaluate condition
:param actual_value: actual value
:param expected_value: expected value
:param comparison_operator: comparison operator
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
input_conditions = []
group_result = []
for condition in conditions:
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
)
if condition.value is not None:
variable_template_parser = VariableTemplateParser(template=condition.value)
expected_value = variable_template_parser.extract_variable_selectors()
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
expected_value = variable_template_parser.format({variable_selector.variable: value})
else:
expected_value = condition.value
else:
expected_value = None
comparison_operator = condition.comparison_operator
input_conditions.append(
{
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": comparison_operator
}
)
result = self.evaluate_condition(actual_value, expected_value, comparison_operator)
group_result.append(result)
return input_conditions, group_result
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains