feat: support pinning, including, and excluding for model providers and tools (#7419)

Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
Xiyuan Chen 2024-08-20 23:16:43 -04:00 committed by GitHub
parent 6c25d7bed3
commit 4e7b6aec3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 363 additions and 57 deletions

View File

@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
)
def _split_csv(raw: str) -> list[str]:
    """Split a comma-separated string into trimmed, non-empty items, preserving order."""
    return [item for item in (part.strip() for part in raw.split(',')) if item]


class PositionConfig(BaseSettings):
    """
    Position configs

    Holds the raw comma-separated settings used to pin, include, and exclude
    model providers and tools, and exposes each of them as a parsed collection
    through a computed field.
    """

    POSITION_PROVIDER_PINS: str = Field(
        description='The heads of model providers',
        default='',
    )

    POSITION_PROVIDER_INCLUDES: str = Field(
        description='The included model providers',
        default='',
    )

    POSITION_PROVIDER_EXCLUDES: str = Field(
        description='The excluded model providers',
        default='',
    )

    POSITION_TOOL_PINS: str = Field(
        description='The heads of tools',
        default='',
    )

    POSITION_TOOL_INCLUDES: str = Field(
        description='The included tools',
        default='',
    )

    POSITION_TOOL_EXCLUDES: str = Field(
        description='The excluded tools',
        default='',
    )

    # Pins are exposed as lists because their order is the display order;
    # includes/excludes are exposed as sets because only membership matters.

    @computed_field
    def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
        return _split_csv(self.POSITION_PROVIDER_PINS)

    @computed_field
    def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
        return set(_split_csv(self.POSITION_PROVIDER_INCLUDES))

    @computed_field
    def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
        return set(_split_csv(self.POSITION_PROVIDER_EXCLUDES))

    @computed_field
    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
        return _split_csv(self.POSITION_TOOL_PINS)

    @computed_field
    def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
        return set(_split_csv(self.POSITION_TOOL_INCLUDES))

    @computed_field
    def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
        return set(_split_csv(self.POSITION_TOOL_EXCLUDES))
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -466,6 +524,7 @@ class FeatureConfig(
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
    """
    Build the name-to-index mapping for tools, with the configured tool pins placed first.

    The base mapping is read from the YAML position file, then re-indexed so that
    the names in ``dify_config.POSITION_TOOL_PINS_LIST`` come before everything else.
    :param folder_path: the folder that contains the position file
    :param file_name: the YAML file name, default to '_position.yaml'
    :return: a dict with name as key and index as value
    """
    return pin_position_map(
        get_position_map(folder_path, file_name=file_name),
        pin_list=dify_config.POSITION_TOOL_PINS_LIST,
    )
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
    """
    Build the name-to-index mapping for model providers, with the configured provider pins placed first.

    The base mapping is read from the YAML position file, then re-indexed so that
    the names in ``dify_config.POSITION_PROVIDER_PINS_LIST`` come before everything else.
    :param folder_path: the folder that contains the position file
    :param file_name: the YAML file name, default to '_position.yaml'
    :return: a dict with name as key and index as value
    """
    return pin_position_map(
        get_position_map(folder_path, file_name=file_name),
        pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
    )
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
    """
    Pin the items in the pin list to the beginning of the position map.

    Pinned names receive indices 0..k-1 in the order they appear in ``pin_list``
    (a pin is kept even when it is absent from the original map). All remaining
    names follow, in the order given by their original indices. The resulting
    indices are always unique and contiguous.
    :param original_position_map: the position map to be re-ordered (not modified)
    :param pin_list: the list of pins to be put at the beginning
    :return: a new position map with the pins first
    """
    # A dict doubles as an ordered set: repeated pins keep only their first
    # occurrence, so a duplicated pin can neither leave a gap in the indices
    # nor make a later name reuse an index that is already taken.
    ordered_names: dict[str, None] = dict.fromkeys(pin_list)

    # Append the non-pinned names in their original relative order.
    for name in sorted(original_position_map, key=original_position_map.__getitem__):
        ordered_names.setdefault(name)

    return {name: idx for idx, name in enumerate(ordered_names)}
def is_filtered(
    include_set: set[str],
    exclude_set: set[str],
    data: Any,
    name_func: Callable[[Any], str],
) -> bool:
    """
    Check if the object should be filtered out.
    Overall logic: exclude > include
    :param include_set: the set of names to be included; an empty set means no include restriction
    :param exclude_set: the set of names to be excluded
    :param data: the data to be filtered; falsy data is never filtered out
    :param name_func: the function to get the name of the object
    :return: True if the object should be filtered out, False otherwise
    """
    # Nothing to decide: no data, or neither filter is configured.
    # name_func is deliberately not called in these cases.
    if not data or not (include_set or exclude_set):
        return False

    name = name_func(data)
    # Exclusion wins over inclusion; the include filter only applies when it is non-empty.
    return name in exclude_set or bool(include_set and name not in include_set)
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],

View File

@ -368,6 +368,15 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
    """
    Return the first provider and the first model of that provider.

    Delegates to the provider manager's ``get_first_provider_first_model``.
    :param tenant_id: tenant id
    :param model_type: model type
    :return: provider name, model name
    """
    provider_manager = self._provider_manager
    return provider_manager.get_first_provider_first_model(tenant_id, model_type)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
@ -502,7 +511,6 @@ class LBModelManager:
config.id
)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res

View File

@ -151,9 +151,9 @@ class AIModel(ABC):
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
]
# get _position.yaml file path

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -234,7 +234,7 @@ class ModelProviderFactory:
]
# get _position.yaml file path
position_map = get_position_map(model_providers_path)
position_map = get_provider_position_map(model_providers_path)
# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []

View File

@ -5,6 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -45,6 +43,7 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
@ -117,6 +116,16 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
continue
provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@ -271,6 +280,24 @@ class ProviderManager:
)
)
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
    """
    Get the names of the first model of the given type and of its provider.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: provider name, model name
    """
    # Models are requested with only_active=False.
    candidate_models = self.get_configurations(tenant_id).get_models(
        model_type=model_type,
        only_active=False
    )

    # NOTE(review): indexing [0] raises IndexError if no model is returned for
    # this model type - confirm callers can never hit that case.
    first_model = candidate_models[0]
    return first_model.provider.provider, first_model.model
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
-> TenantDefaultModel:
"""

View File

@ -1,6 +1,6 @@
import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider
@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
def name_func(provider: UserToolProvider) -> str:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers
return sorted_providers

View File

@ -10,14 +10,11 @@ from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
)
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
@ -107,7 +102,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@ -346,7 +341,7 @@ class ToolManager:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
@ -414,6 +409,15 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name),
@ -473,7 +477,7 @@ class ToolManager:
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiToolProviderController, dict[str, Any]]:
ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
@ -593,4 +597,5 @@ class ToolManager:
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()

View File

@ -111,6 +111,12 @@ class AppService:
'completion_params': {}
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id,
model_type=ModelType.LLM
)
default_model_config['model']['provider'] = provider
default_model_config['model']['name'] = model
default_model_dict = default_model_config['model']
default_model_config['model'] = json.dumps(default_model_dict)
@ -190,13 +196,14 @@ class AppService:
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app

View File

@ -30,6 +30,7 @@ class ModelProviderService:
"""
Model Provider Service
"""
def __init__(self) -> None:
self.provider_manager = ProviderManager()
@ -387,18 +388,21 @@ class ModelProviderService:
tenant_id=tenant_id,
model_type=model_type_enum
)
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
try:
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
except Exception as e:
logger.info(f"get_default_model_of_model_type error: {e}")
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
"""

View File

@ -1,6 +1,8 @@
import json
import logging
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
@ -43,14 +45,14 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
@ -78,7 +80,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.provider == provider_name,
).first()
try:
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
@ -119,8 +121,8 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
@ -135,7 +137,7 @@ class BuiltinToolManageService:
if provider is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@ -156,7 +158,7 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
@ -165,8 +167,8 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
@ -179,7 +181,7 @@ class BuiltinToolManageService:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
@ -202,6 +204,15 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@ -226,4 +237,3 @@ class BuiltinToolManageService:
raise e
return BuiltinToolProviderSort.sort(result)

View File

@ -2,7 +2,7 @@ from textwrap import dedent
import pytest
from core.helper.position_helper import get_position_map
from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map
@pytest.fixture
@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
- second
# - commented
- third
- 9999999999999
- forth
"""))
@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
"""\
# - commented1
# - commented2
-
-
-
-
"""))
return str(tmp_path)
@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya
folder_path=prepare_empty_commented_positions_yaml,
file_name='example_positions_all_commented.yaml')
assert position_map == {}
def test_excluded_position_data(prepare_example_positions_yaml):
    """Excluded names are dropped; the rest is expected pins first, then file order, then unknown names."""
    pinned_map = pin_position_map(
        original_position_map=get_position_map(
            folder_path=prepare_example_positions_yaml,
            file_name='example_positions.yaml'
        ),
        pin_list=['forth', 'first']
    )

    includes = set()
    excludes = {'9999999999999'}
    candidates = [
        "forth",
        "first",
        "second",
        "third",
        "9999999999999",
        "extra1",
        "extra2",
    ]

    # drop everything the include/exclude rules reject
    kept = [name for name in candidates if not is_filtered(includes, excludes, name, lambda x: x)]

    # order the survivors by the pinned position map
    ordered = sort_by_position_map(
        position_map=pinned_map,
        data=kept,
        name_func=lambda x: x,
    )

    # the excluded entry is gone and the pins lead the result
    assert ordered == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
def test_included_position_data(prepare_example_positions_yaml):
    """With a non-empty include set, only the included names survive, ordered by the pinned map."""
    position_map = get_position_map(
        folder_path=prepare_example_positions_yaml,
        file_name='example_positions.yaml'
    )
    pin_list = ['forth', 'first']
    include_set = {'forth', 'first'}
    # `{}` would create an empty dict; is_filtered expects a set
    exclude_set = set()

    position_map = pin_position_map(
        original_position_map=position_map,
        pin_list=pin_list
    )

    data = [
        "forth",
        "first",
        "second",
        "third",
        "9999999999999",
        "extra1",
        "extra2",
    ]

    # filter out the data
    data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]

    # sort data by position map
    sorted_data = sort_by_position_map(
        position_map=position_map,
        data=data,
        name_func=lambda x: x,
    )

    # assert the result in the correct order
    assert sorted_data == ['forth', 'first']

View File

@ -701,3 +701,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate}
# ------------------------------
EXPOSE_NGINX_PORT=80
EXPOSE_NGINX_SSL_PORT=443
# ----------------------------------------------------------------------------
# ModelProvider & Tool Position Configuration
# Used to specify the model providers and tools that can be used in the app.
# ----------------------------------------------------------------------------
# Pin, include, and exclude tools
# Use comma-separated values with no spaces between items.
# Example: POSITION_TOOL_PINS=bing,google
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
# Pin, include, and exclude model providers
# Use comma-separated values with no spaces between items.
# Example: POSITION_PROVIDER_PINS=openai,openllm
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=