Feat/dify billing (#1679)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: takatost <takatost@users.noreply.github.com>
This commit is contained in:
Garfield Dai 2023-12-03 20:59:29 +08:00 committed by GitHub
parent d3a2c0ed34
commit 053102f433
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 182 additions and 2 deletions

View File

@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
# Stripe configuration
STRIPE_API_KEY= STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET= STRIPE_WEBHOOK_SECRET=
# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=

View File

@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers # Import webhook controllers
from .webhook import stripe from .webhook import stripe
from .billing import billing

View File

@ -0,0 +1,85 @@
import stripe
import os
from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import only_edition_cloud
from libs.login import login_required
from services.billing_service import BillingService
class BillingInfo(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
edition = current_app.config['EDITION']
if edition != 'CLOUD':
return {"enabled": False}
return BillingService.get_info(current_user.current_tenant_id)
class Subscription(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
class Invoices(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
return BillingService.get_invoices(current_user.email)
class StripeBillingWebhook(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
BillingService.process_event(event)
return 'success', 200
api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

View File

@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')

View File

@ -0,0 +1,55 @@
import os
import requests
from services.dataset_service import DatasetService
class BillingService:
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
@classmethod
def get_info(cls, tenant_id: str):
params = {'tenant_id': tenant_id}
billing_info = cls._send_request('GET', '/info', params=params)
vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024
billing_info['vector_space']['size'] = int(vector_size)
return billing_info
@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
params = {
'plan': plan,
'interval': interval,
'prefilled_email': prefilled_email,
'user_name': user_name,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/subscription', params=params)
@classmethod
def get_invoices(cls, prefilled_email: str = ''):
params = {'prefilled_email': prefilled_email}
return cls._send_request('GET', '/invoices', params=params)
@classmethod
def _send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Billing-Api-Secret-Key": cls.secret_key
}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
return response.json()
@classmethod
def process_event(cls, event: dict):
json = {
"content": event,
}
return cls._send_request('POST', '/webhook/stripe', json=json)

View File

@ -227,6 +227,36 @@ class DatasetService:
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
.order_by(db.desc(AppDatasetJoin.created_at)).all() .order_by(db.desc(AppDatasetJoin.created_at)).all()
@staticmethod
def get_tenant_datasets_usage(tenant_id):
# get the high_quality datasets
dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
Dataset.tenant_id == tenant_id).all()
if not dataset_ids:
return 0
dataset_ids = [result[0] for result in dataset_ids]
document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
Document.tenant_id == tenant_id,
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False
).all()
if not document_ids:
return 0
document_ids = [result[0] for result in document_ids]
document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
DocumentSegment.tenant_id == tenant_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.enabled == True,
).all()
if not document_segments:
return 0
total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
total_vector_size = 1536 * 4 * len(document_segments)
return total_words_size + total_vector_size
class DocumentService: class DocumentService:
DEFAULT_RULES = { DEFAULT_RULES = {
@ -488,7 +518,8 @@ class DocumentService:
'score_threshold_enabled': False 'score_threshold_enabled': False
} }
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
'retrieval_model') else default_retrieval_model
documents = [] documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))