From 430e10014288cf6a7f1852945618835976ffa3dd Mon Sep 17 00:00:00 2001 From: Shota Totsuka <153569547+totsukash@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:45:03 +0900 Subject: [PATCH] refactor: Add @staticmethod decorator in `api/core` (#7652) --- api/core/hosting_configuration.py | 18 ++++++++++++------ api/core/indexing_runner.py | 27 ++++++++++++++++++--------- api/core/model_manager.py | 10 ++++++---- api/core/provider_manager.py | 24 ++++++++++++++++-------- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 5f7fec5833..ddcd751286 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -58,7 +58,8 @@ class HostingConfiguration: self.moderation_config = self.init_moderation_config(config) - def init_azure_openai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_azure_openai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TIMES if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): credentials = { @@ -145,7 +146,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_anthropic(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_anthropic(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas = [] @@ -180,7 +182,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_minimax(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_minimax(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_MINIMAX_ENABLED"): quotas = [FreeHostingQuota()] @@ -197,7 +200,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_spark(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_spark(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_SPARK_ENABLED"): quotas = [FreeHostingQuota()] @@ -214,7 +218,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_zhipuai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_zhipuai(app_config: Config) -> HostingProvider: quota_unit = QuotaUnit.TOKENS if app_config.get("HOSTED_ZHIPUAI_ENABLED"): quotas = [FreeHostingQuota()] @@ -231,7 +236,8 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: + @staticmethod + def init_moderation_config(app_config: Config) -> HostedModerationConfig: if app_config.get("HOSTED_MODERATION_ENABLED") \ and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 8173028ed7..dddf5567c1 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -411,7 +411,8 @@ class IndexingRunner: return text_docs - def filter_string(self, text): + @staticmethod + def filter_string(text): text = re.sub(r'<\|', '<', text) text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) @@ -419,7 +420,8 @@ class IndexingRunner: text = re.sub('\uFFFE', '', text) return text - def _get_splitter(self, processing_rule: DatasetProcessRule, + @staticmethod + def _get_splitter(processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. @@ -611,7 +613,8 @@ class IndexingRunner: return all_documents - def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: + @staticmethod + def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: """ Clean the document text according to the processing rules. """ @@ -640,7 +643,8 @@ class IndexingRunner: return text - def format_split_text(self, text): + @staticmethod + def format_split_text(text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) @@ -704,7 +708,8 @@ class IndexingRunner: } ) - def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): + @staticmethod + def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: @@ -758,13 +763,15 @@ class IndexingRunner: return tokens - def _check_document_paused_status(self, document_id: str): + @staticmethod + def _check_document_paused_status(document_id: str): indexing_cache_key = 'document_{}_is_paused'.format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() - def _update_document_index_status(self, document_id: str, after_indexing_status: str, + @staticmethod + def _update_document_index_status(document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None) -> None: """ Update the document indexing status. @@ -786,14 +793,16 @@ class IndexingRunner: DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() - def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: + @staticmethod + def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. """ DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): + @staticmethod + def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 1ceed8043c..7b1a7ada5b 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -44,7 +44,8 @@ class ModelInstance: credentials=self.credentials ) - def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: + @staticmethod + def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -63,7 +64,8 @@ class ModelInstance: return credentials - def _get_load_balancing_manager(self, configuration: ProviderConfiguration, + @staticmethod + def _get_load_balancing_manager(configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict) -> Optional["LBModelManager"]: @@ -515,8 +517,8 @@ class LBModelManager: res = cast(bool, res) return res - @classmethod - def get_config_in_cooldown_and_ttl(cls, tenant_id: str, + @staticmethod + def get_config_in_cooldown_and_ttl(tenant_id: str, provider: str, model_type: ModelType, model: str, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6c68cee7be..67eee2c294 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -350,7 +350,8 @@ class ProviderManager: return default_model - def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: + @staticmethod + def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: """ Get all provider records of the workspace. @@ -369,7 +370,8 @@ class ProviderManager: return provider_name_to_provider_records_dict - def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: + @staticmethod + def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. @@ -389,7 +391,8 @@ class ProviderManager: return provider_name_to_provider_model_records_dict - def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: + @staticmethod + def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. @@ -408,7 +411,8 @@ class ProviderManager: return provider_name_to_preferred_provider_type_records_dict - def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + @staticmethod + def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. @@ -427,7 +431,8 @@ class ProviderManager: return provider_name_to_provider_model_settings_dict - def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + @staticmethod + def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. @@ -458,7 +463,8 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict - def _init_trial_provider_records(self, tenant_id: str, + @staticmethod + def _init_trial_provider_records(tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -791,7 +797,8 @@ class ProviderManager: credentials=current_using_credentials ) - def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: + @staticmethod + def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: """ Choice current using quota type. paid quotas > provider free quotas > hosting trial quotas @@ -818,7 +825,8 @@ class ProviderManager: raise ValueError('No quota type available') - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + @staticmethod + def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables.