mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Merge branch 'feat/custom-model&tool-order' into deploy/dev
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in:
commit
da0363d579
|
@ -485,24 +485,24 @@ class PositionConfig(BaseSettings):
|
|||
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_INCLUDES_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != '']
|
||||
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_EXCLUDES_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != '']
|
||||
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_INCLUDES_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != '']
|
||||
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_EXCLUDES_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != '']
|
||||
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
# Used to mark a position as excluded.
|
||||
# See api/core/helper/position_helper.py for more details.
|
||||
POSITION_EXCLUDED = -999
|
|
@ -4,7 +4,6 @@ from collections.abc import Callable
|
|||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from constants.position import POSITION_EXCLUDED
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
|
@ -30,11 +29,9 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -
|
|||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
|
||||
return sort_and_filter_position_map(
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
|
||||
include_list=dify_config.POSITION_TOOL_INCLUDES_LIST,
|
||||
exclude_list=dify_config.POSITION_TOOL_EXCLUDES_LIST
|
||||
)
|
||||
|
||||
|
||||
|
@ -46,46 +43,62 @@ def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml
|
|||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
return sort_and_filter_position_map(
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
|
||||
include_list=dify_config.POSITION_PROVIDER_INCLUDES_LIST,
|
||||
exclude_list=dify_config.POSITION_PROVIDER_EXCLUDES_LIST
|
||||
)
|
||||
|
||||
|
||||
def sort_and_filter_position_map(original_position_map: dict[str, int], pin_list: list[str], include_list: list[str], exclude_list: list[str]) -> dict[str, int]:
|
||||
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Sort and filter the positions
|
||||
Pin the items in the pin list to the beginning of the position map.
|
||||
:param position_map: the position map to be sorted and filtered
|
||||
:param pin_list: the list of pins to be put at the beginning
|
||||
:param include_set: the set of names to be included
|
||||
:param exclude_set: the set of names to be excluded
|
||||
:return: the sorted and filtered position map
|
||||
:return: the sorted position map
|
||||
"""
|
||||
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
|
||||
include_set = set(include_list) if include_list else set(positions)
|
||||
exclude_set = set(exclude_list) if exclude_list else set()
|
||||
|
||||
# Add pins to position map
|
||||
position_map = {name: idx for idx, name in enumerate(pin_list)}
|
||||
|
||||
# Add remaining positions to position map, respecting include and exclude lists
|
||||
# Add remaining positions to position map
|
||||
start_idx = len(position_map)
|
||||
for name in positions:
|
||||
if name in position_map:
|
||||
continue # skip pinned items
|
||||
if name in exclude_set:
|
||||
position_map[name] = POSITION_EXCLUDED
|
||||
elif name in include_set:
|
||||
if name not in position_map:
|
||||
position_map[name] = start_idx
|
||||
start_idx += 1
|
||||
else:
|
||||
position_map[name] = POSITION_EXCLUDED
|
||||
|
||||
return position_map
|
||||
|
||||
|
||||
def is_filtered(
|
||||
include_set: set[str],
|
||||
exclude_set: set[str],
|
||||
data: Any,
|
||||
name_func: Callable[[Any], str],
|
||||
) -> bool:
|
||||
"""
|
||||
Chcek if the object should be filtered out.
|
||||
:param include_set: the set of names to be included
|
||||
:param exclude_set: the set of names to be excluded
|
||||
:param name_func: the function to get the name of the object
|
||||
:param data: the data to be filtered
|
||||
:return: True if the object should be filtered out, False otherwise
|
||||
"""
|
||||
if not data:
|
||||
return False
|
||||
if not include_set and not exclude_set:
|
||||
return False
|
||||
|
||||
name = name_func(data)
|
||||
|
||||
if name in exclude_set: # exclude_set is prioritized
|
||||
return True
|
||||
if include_set and name not in include_set: # filter out only if include_set is not empty
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
position_map: dict[str, int],
|
||||
data: list[Any],
|
||||
|
@ -102,17 +115,7 @@ def sort_by_position_map(
|
|||
if not position_map or not data:
|
||||
return data
|
||||
|
||||
# filter out the items that are marked "excluded" in the position map
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
name = name_func(item)
|
||||
if name in position_map: # case 1: name is in the position map
|
||||
if position_map[name] != POSITION_EXCLUDED:
|
||||
filtered_data.append(item)
|
||||
else: # case 2: name is not in the position map
|
||||
filtered_data.append(item)
|
||||
|
||||
return sorted(filtered_data, key=lambda x: position_map.get(name_func(x), float('inf')))
|
||||
return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))
|
||||
|
||||
|
||||
def sort_to_dict_by_position_map(
|
||||
|
|
|
@ -37,8 +37,3 @@
|
|||
- siliconflow
|
||||
- perfxcloud
|
||||
- zhinao
|
||||
- novita
|
||||
- sagemaker
|
||||
- leptonai
|
||||
- stepfun
|
||||
- huggingface_tei
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
|
@ -18,6 +19,7 @@ from core.entities.provider_entities import (
|
|||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
|
@ -114,6 +116,16 @@ class ProviderManager:
|
|||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
for provider_entity in provider_entities:
|
||||
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
continue
|
||||
|
||||
provider_name = provider_entity.provider
|
||||
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
||||
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
||||
|
|
|
@ -32,29 +32,3 @@
|
|||
- feishu_base
|
||||
- slack
|
||||
- tianditu
|
||||
- hap
|
||||
- websearch
|
||||
- devdocs
|
||||
- regex
|
||||
- getimgai
|
||||
- gitlab
|
||||
- cogview
|
||||
- json_process
|
||||
- firecrawl
|
||||
- google_translate
|
||||
- stackexchange
|
||||
- brave
|
||||
- novitaaai
|
||||
- vanna
|
||||
- twilio
|
||||
- openweather
|
||||
- spider
|
||||
- judge0ce
|
||||
- spark
|
||||
- tavirly
|
||||
- did
|
||||
- stepfun
|
||||
- trello
|
||||
- aws
|
||||
- novitaai
|
||||
- tavily
|
||||
|
|
|
@ -10,14 +10,11 @@ from configs import dify_config
|
|||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
|
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
|
|||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
|
@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
|
@ -107,7 +102,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
|
@ -346,7 +341,7 @@ class ToolManager:
|
|||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
|
@ -414,6 +409,15 @@ class ToolManager:
|
|||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
|
@ -473,7 +477,7 @@ class ToolManager:
|
|||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
|
@ -593,4 +597,5 @@ class ToolManager:
|
|||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
|
|
|
@ -401,7 +401,7 @@ class ModelProviderService:
|
|||
)
|
||||
) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"get_default_model_of_model_type error: {e}")
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
|
@ -43,14 +45,14 @@ class BuiltinToolManageService:
|
|||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
|
@ -78,7 +80,7 @@ class BuiltinToolManageService:
|
|||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
try:
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
|
@ -119,8 +121,8 @@ class BuiltinToolManageService:
|
|||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
|
@ -135,7 +137,7 @@ class BuiltinToolManageService:
|
|||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
|
@ -156,7 +158,7 @@ class BuiltinToolManageService:
|
|||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
|
@ -165,8 +167,8 @@ class BuiltinToolManageService:
|
|||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
|
@ -179,7 +181,7 @@ class BuiltinToolManageService:
|
|||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
|
@ -202,6 +204,15 @@ class BuiltinToolManageService:
|
|||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
|
@ -226,4 +237,3 @@ class BuiltinToolManageService:
|
|||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_and_filter_position_map
|
||||
from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
|||
- second
|
||||
# - commented
|
||||
- third
|
||||
|
||||
|
||||
- 9999999999999
|
||||
- forth
|
||||
"""))
|
||||
|
@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
|
|||
"""\
|
||||
# - commented1
|
||||
# - commented2
|
||||
-
|
||||
-
|
||||
|
||||
-
|
||||
-
|
||||
|
||||
"""))
|
||||
return str(tmp_path)
|
||||
|
||||
|
@ -55,45 +55,77 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya
|
|||
assert position_map == {}
|
||||
|
||||
|
||||
def test_excluded_position_map(prepare_example_positions_yaml):
|
||||
def test_excluded_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['forth', 'first']
|
||||
include_list = []
|
||||
exclude_list = ['9999999999999']
|
||||
sorted_filtered_position_map = sort_and_filter_position_map(
|
||||
include_set = set()
|
||||
exclude_set = {'9999999999999'}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list,
|
||||
include_list=include_list,
|
||||
exclude_list=exclude_list
|
||||
pin_list=pin_list
|
||||
)
|
||||
assert sorted_filtered_position_map == {
|
||||
'forth': 0,
|
||||
'first': 1,
|
||||
'second': 2,
|
||||
'third': 3,
|
||||
}
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
"first",
|
||||
"second",
|
||||
"third",
|
||||
"9999999999999",
|
||||
"extra1",
|
||||
"extra2",
|
||||
]
|
||||
|
||||
# filter out the data
|
||||
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||
|
||||
# sort data by position map
|
||||
sorted_data = sort_by_position_map(
|
||||
position_map=position_map,
|
||||
data=data,
|
||||
name_func=lambda x: x,
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
|
||||
|
||||
|
||||
def test_included_position_map(prepare_example_positions_yaml):
|
||||
def test_included_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['second', 'first']
|
||||
include_list = ['first', 'second', 'third', 'forth']
|
||||
exclude_list = []
|
||||
sorted_filtered_position_map = sort_and_filter_position_map(
|
||||
pin_list = ['forth', 'first']
|
||||
include_set = {'forth', 'first'}
|
||||
exclude_set = {}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list,
|
||||
include_list=include_list,
|
||||
exclude_list=exclude_list
|
||||
pin_list=pin_list
|
||||
)
|
||||
assert sorted_filtered_position_map == {
|
||||
'second': 0,
|
||||
'first': 1,
|
||||
'third': 2,
|
||||
'forth': 3,
|
||||
}
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
"first",
|
||||
"second",
|
||||
"third",
|
||||
"9999999999999",
|
||||
"extra1",
|
||||
"extra2",
|
||||
]
|
||||
|
||||
# filter out the data
|
||||
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||
|
||||
# sort data by position map
|
||||
sorted_data = sort_by_position_map(
|
||||
position_map=position_map,
|
||||
data=data,
|
||||
name_func=lambda x: x,
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first']
|
||||
|
|
Loading…
Reference in New Issue
Block a user