Fix/no trial provider (#823)

This commit is contained in:
takatost 2023-08-13 14:56:32 +08:00 committed by GitHub
parent 7898937eae
commit 8e15ba6cd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 4 deletions

View File

@ -168,10 +168,34 @@ class ModelProviderFactory:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType: for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \ if quota_type in model_provider_rules['system_config']['supported_quota_types']:
and quota_type in quota_type_to_provider_dict.keys(): if quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type] provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used: if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
elif quota_type == ProviderQuotaType.TRIAL.value:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.SYSTEM.value,
is_valid=True,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=model_provider_rules['system_config']['quota_limit'],
quota_used=0
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value
).first()
return provider return provider
no_system_provider = True no_system_provider = True

View File

@ -23,6 +23,14 @@ class ProviderService:
# get rules for all providers # get rules for all providers
model_provider_rules = ModelProviderFactory.get_provider_rules() model_provider_rules = ModelProviderFactory.get_provider_rules()
model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()] model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
for model_provider_name, model_provider_rule in model_provider_rules.items():
if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
and 'supported_quota_types' in model_provider_rule['system_config'] \
and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
configurable_model_provider_names = [ configurable_model_provider_names = [
model_provider_name model_provider_name
for model_provider_name, model_provider_rules in model_provider_rules.items() for model_provider_name, model_provider_rules in model_provider_rules.items()