From 8e15ba6cd63f50e60a2321fc330042a1aee18182 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 13 Aug 2023 14:56:32 +0800 Subject: [PATCH] Fix/no trial provider (#823) --- .../model_providers/model_provider_factory.py | 32 ++++++++++++++++--- api/services/provider_service.py | 8 +++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py index e2d8b43603..6cb2f8fa46 100644 --- a/api/core/model_providers/model_provider_factory.py +++ b/api/core/model_providers/model_provider_factory.py @@ -168,10 +168,34 @@ class ModelProviderFactory: model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) for quota_type_enum in ProviderQuotaType: quota_type = quota_type_enum.value - if quota_type in model_provider_rules['system_config']['supported_quota_types'] \ - and quota_type in quota_type_to_provider_dict.keys(): - provider = quota_type_to_provider_dict[quota_type] - if provider.is_valid and provider.quota_limit > provider.quota_used: + if quota_type in model_provider_rules['system_config']['supported_quota_types']: + if quota_type in quota_type_to_provider_dict.keys(): + provider = quota_type_to_provider_dict[quota_type] + if provider.is_valid and provider.quota_limit > provider.quota_used: + return provider + elif quota_type == ProviderQuotaType.TRIAL.value: + try: + provider = Provider( + tenant_id=tenant_id, + provider_name=model_provider_name, + provider_type=ProviderType.SYSTEM.value, + is_valid=True, + quota_type=ProviderQuotaType.TRIAL.value, + quota_limit=model_provider_rules['system_config']['quota_limit'], + quota_used=0 + ) + db.session.add(provider) + db.session.commit() + except IntegrityError: + db.session.rollback() + provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value + ).first() + return provider no_system_provider = True diff --git a/api/services/provider_service.py b/api/services/provider_service.py index de8f53d8fc..f061e68d92 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -23,6 +23,14 @@ class ProviderService: # get rules for all providers model_provider_rules = ModelProviderFactory.get_provider_rules() model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()] + + for model_provider_name, model_provider_rule in model_provider_rules.items(): + if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \ + and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \ + and 'supported_quota_types' in model_provider_rule['system_config'] \ + and 'trial' in model_provider_rule['system_config']['supported_quota_types']: + ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + configurable_model_provider_names = [ model_provider_name for model_provider_name, model_provider_rules in model_provider_rules.items()