mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
chore: improve position map conversion and tolerate empty position yaml file (#6541)
This commit is contained in:
parent
c8da4a1b7e
commit
20268708cc
|
@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
|||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_file_name = os.path.join(folder_path, file_name)
|
||||
if not position_file_name or not os.path.exists(position_file_name):
|
||||
return {}
|
||||
|
||||
positions = load_yaml_file(position_file_name, ignore_error=True)
|
||||
position_map = {}
|
||||
index = 0
|
||||
for _, name in enumerate(positions):
|
||||
if name and isinstance(name, str):
|
||||
position_map[name.strip()] = index
|
||||
index += 1
|
||||
return position_map
|
||||
position_file_path = os.path.join(folder_path, file_name)
|
||||
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
|
||||
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
|
||||
return {name: index for index, name in enumerate(positions)}
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
|
|
|
@ -162,7 +162,7 @@ class AIModel(ABC):
|
|||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||
|
|
|
@ -44,7 +44,7 @@ class ModelProvider(ABC):
|
|||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
|
||||
yaml_data = load_yaml_file(yaml_path)
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
|
|
|
@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
provider = self.__class__.__module__.split('.')[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
|
||||
try:
|
||||
provider_yaml = load_yaml_file(yaml_path)
|
||||
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
|
||||
except Exception as e:
|
||||
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
|
||||
|
||||
|
@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
for tool_file in tool_files:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file))
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
|
|
|
@ -1,35 +1,31 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file to a dict
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return empty dict if error occurs and the error will be logged in warning level
|
||||
if True, return default_value if error occurs and the error will be logged in warning level
|
||||
if False, raise error if error occurs
|
||||
:return: a dict of the YAML content
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
try:
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
|
||||
|
||||
with open(file_path, encoding='utf-8') as file:
|
||||
with open(file_path, encoding='utf-8') as yaml_file:
|
||||
try:
|
||||
return yaml.safe_load(file)
|
||||
return yaml.safe_load(yaml_file)
|
||||
except Exception as e:
|
||||
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
|
||||
except FileNotFoundError as e:
|
||||
logger.debug(f'Failed to load YAML file {file_path}: {e}')
|
||||
return {}
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
logger.warning(f'Failed to load YAML file {file_path}: {e}')
|
||||
return {}
|
||||
return default_value
|
||||
else:
|
||||
raise e
|
||||
|
|
|
@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
|||
return str(tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
|
||||
"""\
|
||||
# - commented1
|
||||
# - commented2
|
||||
-
|
||||
-
|
||||
|
||||
"""))
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def test_position_helper(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
|
@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml):
|
|||
'third': 2,
|
||||
'forth': 3,
|
||||
}
|
||||
|
||||
|
||||
def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_empty_commented_positions_yaml,
|
||||
file_name='example_positions_all_commented.yaml')
|
||||
assert position_map == {}
|
||||
|
|
|
@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file():
|
|||
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
|
||||
assert load_yaml_file(file_path='') == {}
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
|
||||
|
||||
|
||||
def test_load_valid_yaml_file(prepare_example_yaml_file):
|
||||
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
|
||||
|
@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
|
|||
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
|
||||
# yaml syntax error
|
||||
with pytest.raises(YAMLError):
|
||||
load_yaml_file(file_path=prepare_invalid_yaml_file)
|
||||
load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
|
||||
|
||||
# ignore error
|
||||
assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
|
||||
assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}
|
||||
|
|
Loading…
Reference in New Issue
Block a user