mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
Fix/no trial provider (#823)
This commit is contained in:
parent
7898937eae
commit
8e15ba6cd6
|
@ -168,10 +168,34 @@ class ModelProviderFactory:
|
||||||
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
||||||
for quota_type_enum in ProviderQuotaType:
|
for quota_type_enum in ProviderQuotaType:
|
||||||
quota_type = quota_type_enum.value
|
quota_type = quota_type_enum.value
|
||||||
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
|
if quota_type in model_provider_rules['system_config']['supported_quota_types']:
|
||||||
and quota_type in quota_type_to_provider_dict.keys():
|
if quota_type in quota_type_to_provider_dict.keys():
|
||||||
provider = quota_type_to_provider_dict[quota_type]
|
provider = quota_type_to_provider_dict[quota_type]
|
||||||
if provider.is_valid and provider.quota_limit > provider.quota_used:
|
if provider.is_valid and provider.quota_limit > provider.quota_used:
|
||||||
|
return provider
|
||||||
|
elif quota_type == ProviderQuotaType.TRIAL.value:
|
||||||
|
try:
|
||||||
|
provider = Provider(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_name=model_provider_name,
|
||||||
|
provider_type=ProviderType.SYSTEM.value,
|
||||||
|
is_valid=True,
|
||||||
|
quota_type=ProviderQuotaType.TRIAL.value,
|
||||||
|
quota_limit=model_provider_rules['system_config']['quota_limit'],
|
||||||
|
quota_used=0
|
||||||
|
)
|
||||||
|
db.session.add(provider)
|
||||||
|
db.session.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
db.session.rollback()
|
||||||
|
provider = db.session.query(Provider) \
|
||||||
|
.filter(
|
||||||
|
Provider.tenant_id == tenant_id,
|
||||||
|
Provider.provider_name == model_provider_name,
|
||||||
|
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||||
|
Provider.quota_type == ProviderQuotaType.TRIAL.value
|
||||||
|
).first()
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
no_system_provider = True
|
no_system_provider = True
|
||||||
|
|
|
@ -23,6 +23,14 @@ class ProviderService:
|
||||||
# get rules for all providers
|
# get rules for all providers
|
||||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||||
model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
|
model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
|
||||||
|
|
||||||
|
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
||||||
|
if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
|
||||||
|
and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
|
||||||
|
and 'supported_quota_types' in model_provider_rule['system_config'] \
|
||||||
|
and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
|
||||||
|
ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||||
|
|
||||||
configurable_model_provider_names = [
|
configurable_model_provider_names = [
|
||||||
model_provider_name
|
model_provider_name
|
||||||
for model_provider_name, model_provider_rules in model_provider_rules.items()
|
for model_provider_name, model_provider_rules in model_provider_rules.items()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user