diff --git a/api/.env.example b/api/.env.example index 012c8a5c65..90abb1ef00 100644 --- a/api/.env.example +++ b/api/.env.example @@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE= HOSTED_OPENAI_API_ORGANIZATION= HOSTED_OPENAI_QUOTA_LIMIT=200 HOSTED_OPENAI_PAID_ENABLED=false -HOSTED_OPENAI_PAID_STRIPE_PRICE_ID= -HOSTED_OPENAI_PAID_INCREASE_QUOTA=1 HOSTED_AZURE_OPENAI_ENABLED=false HOSTED_AZURE_OPENAI_API_KEY= @@ -119,16 +117,7 @@ HOSTED_ANTHROPIC_API_BASE= HOSTED_ANTHROPIC_API_KEY= HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 HOSTED_ANTHROPIC_PAID_ENABLED=false -HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID= -HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 -HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 -HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 - -# Stripe configuration -STRIPE_API_KEY= -STRIPE_WEBHOOK_SECRET= # Billing configuration BILLING_API_URL=http://127.0.0.1:8000/v1 -BILLING_API_SECRET_KEY= -STRIPE_WEBHOOK_BILLING_SECRET= \ No newline at end of file +BILLING_API_SECRET_KEY= \ No newline at end of file diff --git a/api/app.py b/api/app.py index a0a09994df..487cd45fb4 100644 --- a/api/app.py +++ b/api/app.py @@ -20,7 +20,7 @@ from flask_cors import CORS from core.model_providers.providers import hosted from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ - ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension + ext_database, ext_storage, ext_mail, ext_code_based_extension from extensions.ext_database import db from extensions.ext_login import login_manager @@ -96,7 +96,6 @@ def initialize_extensions(app): ext_login.init_app(app) ext_mail.init_app(app) ext_sentry.init_app(app) - ext_stripe.init_app(app) # Flask-Login configuration diff --git a/api/config.py b/api/config.py index cdf1ac066e..e07af3ef21 100644 --- a/api/config.py +++ b/api/config.py @@ -1,11 +1,8 @@ # -*- coding:utf-8 -*- import os -from datetime import timedelta import dotenv -from extensions.ext_database import db -from extensions.ext_redis import redis_client dotenv.load_dotenv() @@ -44,15 +41,11 @@ DEFAULTS = { 'HOSTED_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_OPENAI_ENABLED': 'False', 'HOSTED_OPENAI_PAID_ENABLED': 'False', - 'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1, 'HOSTED_AZURE_OPENAI_ENABLED': 'False', 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, 'HOSTED_ANTHROPIC_ENABLED': 'False', 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', - 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000, - 'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20, - 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, 'HOSTED_MODERATION_ENABLED': 'False', 'HOSTED_MODERATION_PROVIDERS': '', 'CLEAN_DAY_SETTING': 30, @@ -268,8 +261,6 @@ class Config: self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') - self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') - self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') @@ -281,10 +272,6 @@ class Config: self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') - self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') - self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')) - self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY')) - self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY')) self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') @@ -302,6 +289,3 @@ class CloudEditionConfig(Config): self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID') self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') - - self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') - self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 99d677970c..46fa0e79dc 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m # Import universal chat controllers from .universal_chat import chat, conversation, message, parameter, audio -# Import webhook controllers -from .webhook import stripe - from .billing import billing diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 6bad91f411..e4bb77dca8 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,9 +1,6 @@ -import stripe -import os - from flask_restful import Resource, reqparse from flask_login import current_user -from flask import current_app, request +from flask import current_app from controllers.console import api from controllers.console.setup import setup_required @@ -40,7 +37,10 @@ class Subscription(Resource): parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) args = parser.parse_args() - return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) + return BillingService.get_subscription(args['plan'], + args['interval'], + current_user.email, + current_user.current_tenant_id) class Invoices(Resource): @@ -54,32 +54,6 @@ class Invoices(Resource): return BillingService.get_invoices(current_user.email) -class StripeBillingWebhook(Resource): - - @setup_required - @only_edition_cloud - def post(self): - payload = request.data - sig_header = request.headers.get('STRIPE_SIGNATURE') - webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') - - try: - event = stripe.Webhook.construct_event( - payload, sig_header, webhook_secret - ) - except ValueError as e: - # Invalid payload - return 'Invalid payload', 400 - except stripe.error.SignatureVerificationError as e: - # Invalid signature - return 'Invalid signature', 400 - - BillingService.process_event(event) - - return 'success', 200 - - api.add_resource(BillingInfo, '/billing/info') api.add_resource(Subscription, '/billing/subscription') api.add_resource(Invoices, '/billing/invoices') -api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') diff --git a/api/controllers/console/webhook/__init__.py b/api/controllers/console/webhook/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/controllers/console/webhook/stripe.py b/api/controllers/console/webhook/stripe.py deleted file mode 100644 index 15cce34723..0000000000 --- a/api/controllers/console/webhook/stripe.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -import stripe -from flask import request, current_app -from flask_restful import Resource - -from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import only_edition_cloud -from services.provider_checkout_service import ProviderCheckoutService - - -class StripeWebhookApi(Resource): - @setup_required - @only_edition_cloud - def post(self): - payload = request.data - sig_header = request.headers.get('STRIPE_SIGNATURE') - webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET') - - try: - event = stripe.Webhook.construct_event( - payload, sig_header, webhook_secret - ) - except ValueError as e: - # Invalid payload - return 'Invalid payload', 400 - except stripe.error.SignatureVerificationError as e: - # Invalid signature - return 'Invalid signature', 400 - - # Handle the checkout.session.completed event - if event['type'] == 'checkout.session.completed': - logging.debug(event['data']['object']['id']) - logging.debug(event['data']['object']['amount_subtotal']) - logging.debug(event['data']['object']['currency']) - logging.debug(event['data']['object']['payment_intent']) - logging.debug(event['data']['object']['payment_status']) - logging.debug(event['data']['object']['metadata']) - - session = stripe.checkout.Session.retrieve( - event['data']['object']['id'], - expand=['line_items'], - ) - - logging.debug(session.line_items['data'][0]['quantity']) - - # Fulfill the purchase... - provider_checkout_service = ProviderCheckoutService() - - try: - provider_checkout_service.fulfill_provider_order(event, session.line_items) - except Exception as e: - - logging.debug(str(e)) - return 'success', 200 - - return 'success', 200 - - -api.add_resource(StripeWebhookApi, '/webhook/stripe') diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0cfa8d17dd..e92a07b2ad 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -9,8 +9,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_providers.error import LLMBadRequestError from core.model_providers.providers.base import CredentialsValidateFailedError -from services.provider_checkout_service import ProviderCheckoutService from services.provider_service import ProviderService +from services.billing_service import BillingService class ModelProviderListApi(Resource): @@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider_name: str): - provider_service = ProviderCheckoutService() - provider_checkout = provider_service.create_checkout( - tenant_id=current_user.current_tenant_id, - provider_name=provider_name, - account=current_user - ) + if provider_name != 'anthropic': + raise ValueError(f'provider name {provider_name} is invalid') - return { - 'url': provider_checkout.get_checkout_url() - } + data = BillingService.get_model_provider_payment_link(provider_name=provider_name, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id) + return data class ModelProviderFreeQuotaSubmitApi(Resource): diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index c98a56e510..7e9c383de5 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider): return False - def get_payment_info(self) -> Optional[dict]: - """ - get product info if it payable. - - :return: - """ - if hosted_model_providers.anthropic \ - and hosted_model_providers.anthropic.paid_enabled: - return { - 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id, - 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota, - 'min_quantity': hosted_model_providers.anthropic.paid_min_quantity, - 'max_quantity': hosted_model_providers.anthropic.paid_max_quantity, - } - - return None - @classmethod def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): """ diff --git a/api/core/model_providers/providers/base.py b/api/core/model_providers/providers/base.py index 9b05b4f5fd..1ff03f286d 100644 --- a/api/core/model_providers/providers/base.py +++ b/api/core/model_providers/providers/base.py @@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC): ).update({'last_used': datetime.utcnow()}) db.session.commit() - def get_payment_info(self) -> Optional[dict]: - """ - get product info if it payable. - - :return: - """ - return None - def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel: """ get provider model. diff --git a/api/core/model_providers/providers/hosted.py b/api/core/model_providers/providers/hosted.py index fd90a0a360..df8a769acb 100644 --- a/api/core/model_providers/providers/hosted.py +++ b/api/core/model_providers/providers/hosted.py @@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel): quota_limit: int = 0 """Quota limit for the openai hosted model. -1 means unlimited.""" paid_enabled: bool = False - paid_stripe_price_id: str = None - paid_increase_quota: int = 1 class HostedAzureOpenAI(BaseModel): @@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel): quota_limit: int = 0 """Quota limit for the anthropic hosted model. -1 means unlimited.""" paid_enabled: bool = False - paid_stripe_price_id: str = None - paid_increase_quota: int = 1000000 - paid_min_quantity: int = 20 - paid_max_quantity: int = 100 class HostedModelProviders(BaseModel): @@ -68,8 +62,6 @@ def init_app(app: Flask): api_key=app.config.get("HOSTED_OPENAI_API_KEY"), quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"), paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"), - paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), - paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"), ) if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"): @@ -85,10 +77,6 @@ def init_app(app: Flask): api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"), quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"), paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), - paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), - paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"), - paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), - paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), ) if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): diff --git a/api/core/model_providers/providers/openai_provider.py b/api/core/model_providers/providers/openai_provider.py index 1802302fa3..0e890529c6 100644 --- a/api/core/model_providers/providers/openai_provider.py +++ b/api/core/model_providers/providers/openai_provider.py @@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider): return False - def get_payment_info(self) -> Optional[dict]: - """ - get payment info if it payable. - - :return: - """ - if hosted_model_providers.openai \ - and hosted_model_providers.openai.paid_enabled: - return { - 'product_id': hosted_model_providers.openai.paid_stripe_price_id, - 'increase_quota': hosted_model_providers.openai.paid_increase_quota, - } - - return None - @classmethod def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): """ diff --git a/api/extensions/ext_stripe.py b/api/extensions/ext_stripe.py deleted file mode 100644 index 3a192c081a..0000000000 --- a/api/extensions/ext_stripe.py +++ /dev/null @@ -1,6 +0,0 @@ -import stripe - - -def init_app(app): - if app.config.get('STRIPE_API_KEY'): - stripe.api_key = app.config.get('STRIPE_API_KEY') diff --git a/api/models/provider.py b/api/models/provider.py index 63e9785a96..1ce8e35614 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) -class ProviderOrderPaymentStatus(Enum): - WAIT_PAY = 'wait_pay' - PAID = 'paid' - PAY_FAILED = 'pay_failed' - REFUNDED = 'refunded' - - @staticmethod - def value_of(value): - for member in ProviderOrderPaymentStatus: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - - class ProviderOrder(db.Model): __tablename__ = 'provider_orders' __table_args__ = ( diff --git a/api/requirements.txt b/api/requirements.txt index ca26601e0e..4224ca6c5c 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -46,7 +46,6 @@ websocket-client~=1.6.1 dashscope~=1.11.0 huggingface_hub~=0.16.4 transformers~=4.31.0 -stripe~=5.5.0 pandas==1.5.3 xinference-client~=0.6.4 safetensors==0.3.2 diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 2f425a61c8..865e8e339a 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -10,7 +10,7 @@ class BillingService: def get_info(cls, tenant_id: str): params = {'tenant_id': tenant_id} - billing_info = cls._send_request('GET', '/info', params=params) + billing_info = cls._send_request('GET', '/subscription/info', params=params) return billing_info @@ -18,16 +18,26 @@ class BillingService: def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', - user_name: str = '', tenant_id: str = ''): params = { 'plan': plan, 'interval': interval, 'prefilled_email': prefilled_email, - 'user_name': user_name, 'tenant_id': tenant_id } - return cls._send_request('GET', '/subscription', params=params) + return cls._send_request('GET', '/subscription/payment-link', params=params) + + @classmethod + def get_model_provider_payment_link(cls, + provider_name: str, + tenant_id: str, + account_id: str): + params = { + 'provider_name': provider_name, + 'tenant_id': tenant_id, + 'account_id': account_id + } + return cls._send_request('GET', '/model-provider/payment-link', params=params) @classmethod def get_invoices(cls, prefilled_email: str = ''): @@ -45,10 +55,3 @@ class BillingService: response = requests.request(method, url, json=json, params=params, headers=headers) return response.json() - - @classmethod - def process_event(cls, event: dict): - json = { - "content": event, - } - return cls._send_request('POST', '/webhook/stripe', json=json) diff --git a/api/services/provider_checkout_service.py b/api/services/provider_checkout_service.py deleted file mode 100644 index 4268acf657..0000000000 --- a/api/services/provider_checkout_service.py +++ /dev/null @@ -1,174 +0,0 @@ -import datetime -import logging - -import stripe -from flask import current_app - -from core.model_providers.model_provider_factory import ModelProviderFactory -from extensions.ext_database import db -from models.account import Account -from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType - - -class ProviderCheckout: - def __init__(self, stripe_checkout_session): - self.stripe_checkout_session = stripe_checkout_session - - def get_checkout_url(self): - return self.stripe_checkout_session.url - - -class ProviderCheckoutService: - def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout: - # check provider name is valid - model_provider_rules = ModelProviderFactory.get_provider_rules() - if provider_name not in model_provider_rules: - raise ValueError(f'provider name {provider_name} is invalid') - - model_provider_rule = model_provider_rules[provider_name] - - # check provider name can be paid - self._check_provider_payable(provider_name, model_provider_rule) - - # get stripe checkout product id - paid_provider = self._get_paid_provider(tenant_id, provider_name) - model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) - model_provider = model_provider_class(provider=paid_provider) - payment_info = model_provider.get_payment_info() - if not payment_info: - raise ValueError(f'provider name {provider_name} not support payment') - - payment_product_id = payment_info['product_id'] - payment_min_quantity = payment_info['min_quantity'] - payment_max_quantity = payment_info['max_quantity'] - - # create provider order - provider_order = ProviderOrder( - tenant_id=tenant_id, - provider_name=provider_name, - account_id=account.id, - payment_product_id=payment_product_id, - quantity=1, - payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value - ) - - db.session.add(provider_order) - db.session.flush() - - line_item = { - 'price': f'{payment_product_id}', - 'quantity': payment_min_quantity - } - - if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity: - line_item['adjustable_quantity'] = { - 'enabled': True, - 'minimum': payment_min_quantity, - 'maximum': payment_max_quantity - } - - try: - # create stripe checkout session - checkout_session = stripe.checkout.Session.create( - line_items=[ - line_item - ], - mode='payment', - success_url=current_app.config.get("CONSOLE_WEB_URL") - + f'?provider_name={provider_name}&payment_result=succeeded', - cancel_url=current_app.config.get("CONSOLE_WEB_URL") - + f'?provider_name={provider_name}&payment_result=cancelled', - automatic_tax={'enabled': True}, - ) - except Exception as e: - logging.exception(e) - raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later') - - provider_order.payment_id = checkout_session.id - db.session.commit() - - return ProviderCheckout(checkout_session) - - def fulfill_provider_order(self, event, line_items): - provider_order = db.session.query(ProviderOrder) \ - .filter(ProviderOrder.payment_id == event['data']['object']['id']) \ - .first() - - if not provider_order: - raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}') - - if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value: - raise ValueError( - f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}') - - provider_order.transaction_id = event['data']['object']['payment_intent'] - provider_order.currency = event['data']['object']['currency'] - provider_order.total_amount = event['data']['object']['amount_subtotal'] - provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value - provider_order.paid_at = datetime.datetime.utcnow() - provider_order.updated_at = provider_order.paid_at - - # update provider quota - provider = db.session.query(Provider).filter( - Provider.tenant_id == provider_order.tenant_id, - Provider.provider_name == provider_order.provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.PAID.value - ).first() - - if not provider: - raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, ' - f'provider name: {provider_order.provider_name}') - - model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name) - model_provider = model_provider_class(provider=provider) - payment_info = model_provider.get_payment_info() - - quantity = line_items['data'][0]['quantity'] - - if not payment_info: - increase_quota = 0 - else: - increase_quota = int(payment_info['increase_quota']) * quantity - - if increase_quota > 0: - provider.quota_limit += increase_quota - provider.is_valid = True - - db.session.commit() - - def _check_provider_payable(self, provider_name: str, model_provider_rule: dict): - if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']: - raise ValueError(f'provider name {provider_name} not support payment') - - if 'system_config' not in model_provider_rule: - raise ValueError(f'provider name {provider_name} not support payment') - - if 'supported_quota_types' not in model_provider_rule['system_config']: - raise ValueError(f'provider name {provider_name} not support payment') - - if 'paid' not in model_provider_rule['system_config']['supported_quota_types']: - raise ValueError(f'provider name {provider_name} not support payment') - - def _get_paid_provider(self, tenant_id: str, provider_name: str): - paid_provider = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.PAID.value, - ).first() - - if not paid_provider: - paid_provider = Provider( - tenant_id=tenant_id, - provider_name=provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.PAID.value, - quota_limit=0, - quota_used=0, - ) - db.session.add(paid_provider) - db.session.commit() - - return paid_provider