chore(list_operator): refine exception handling for error specificity (#10206)

This commit is contained in:
-LAN- 2024-11-03 11:55:19 +08:00 committed by Yeuoly
parent 7c45859594
commit b01e7d778e
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 112 additions and 57 deletions

View File

@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""
pass
class InvalidFilterValueError(ListOperatorError):
pass
class InvalidKeyError(ListOperatorError):
pass
class InvalidConditionError(ListOperatorError):
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal
from typing import Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
@ -36,23 +47,59 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value]
else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value
try:
# Filter
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Order
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except ListOperatorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
@ -69,9 +116,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
return variable
# Order
if self.node_data.order_by.enabled:
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
@ -83,23 +132,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
# Slice
if self.node_data.limit.enabled:
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
variable = variable.model_copy(update={"value": result})
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
return variable.model_copy(update={"value": result})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
case "size":
return lambda x: x.size
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "url":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in":
return lambda x: not _in(value)(x)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):