mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Feat/dify billing (#1679)
Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: takatost <takatost@users.noreply.github.com>
This commit is contained in:
parent
d3a2c0ed34
commit
053102f433
|
@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
|
|||
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
||||
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
||||
|
||||
# Stripe configuration
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# Billing configuration
|
||||
BILLING_API_URL=http://127.0.0.1:8000/v1
|
||||
BILLING_API_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_BILLING_SECRET=
|
|
@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio
|
|||
|
||||
# Import webhook controllers
|
||||
from .webhook import stripe
|
||||
|
||||
from .billing import billing
|
||||
|
|
0
api/controllers/console/billing/__init__.py
Normal file
0
api/controllers/console/billing/__init__.py
Normal file
85
api/controllers/console/billing/billing.py
Normal file
85
api/controllers/console/billing/billing.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import stripe
|
||||
import os
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_login import current_user
|
||||
from flask import current_app, request
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class BillingInfo(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
|
||||
edition = current_app.config['EDITION']
|
||||
if edition != 'CLOUD':
|
||||
return {"enabled": False}
|
||||
|
||||
return BillingService.get_info(current_user.current_tenant_id)
|
||||
|
||||
|
||||
class Subscription(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
|
||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||
args = parser.parse_args()
|
||||
|
||||
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
|
||||
|
||||
|
||||
class Invoices(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
|
||||
return BillingService.get_invoices(current_user.email)
|
||||
|
||||
|
||||
class StripeBillingWebhook(Resource):
|
||||
|
||||
@setup_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
payload = request.data
|
||||
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
||||
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
return 'Invalid payload', 400
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
return 'Invalid signature', 400
|
||||
|
||||
BillingService.process_event(event)
|
||||
|
||||
return 'success', 200
|
||||
|
||||
|
||||
api.add_resource(BillingInfo, '/billing/info')
|
||||
api.add_resource(Subscription, '/billing/subscription')
|
||||
api.add_resource(Invoices, '/billing/invoices')
|
||||
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')
|
|
@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
|||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
|
||||
|
|
55
api/services/billing_service.py
Normal file
55
api/services/billing_service.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import requests
|
||||
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {'tenant_id': tenant_id}
|
||||
|
||||
billing_info = cls._send_request('GET', '/info', params=params)
|
||||
|
||||
vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024
|
||||
billing_info['vector_space']['size'] = int(vector_size)
|
||||
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
|
||||
params = {
|
||||
'plan': plan,
|
||||
'interval': interval,
|
||||
'prefilled_email': prefilled_email,
|
||||
'user_name': user_name,
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
return cls._send_request('GET', '/subscription', params=params)
|
||||
|
||||
@classmethod
|
||||
def get_invoices(cls, prefilled_email: str = ''):
|
||||
params = {'prefilled_email': prefilled_email}
|
||||
return cls._send_request('GET', '/invoices', params=params)
|
||||
|
||||
@classmethod
|
||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Billing-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def process_event(cls, event: dict):
|
||||
json = {
|
||||
"content": event,
|
||||
}
|
||||
return cls._send_request('POST', '/webhook/stripe', json=json)
|
|
@ -227,6 +227,36 @@ class DatasetService:
|
|||
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
|
||||
.order_by(db.desc(AppDatasetJoin.created_at)).all()
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_datasets_usage(tenant_id):
|
||||
# get the high_quality datasets
|
||||
dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
|
||||
Dataset.tenant_id == tenant_id).all()
|
||||
if not dataset_ids:
|
||||
return 0
|
||||
dataset_ids = [result[0] for result in dataset_ids]
|
||||
document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
|
||||
Document.tenant_id == tenant_id,
|
||||
Document.completed_at.isnot(None),
|
||||
Document.enabled == True,
|
||||
Document.archived == False
|
||||
).all()
|
||||
if not document_ids:
|
||||
return 0
|
||||
document_ids = [result[0] for result in document_ids]
|
||||
document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.enabled == True,
|
||||
).all()
|
||||
if not document_segments:
|
||||
return 0
|
||||
|
||||
total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
|
||||
total_vector_size = 1536 * 4 * len(document_segments)
|
||||
|
||||
return total_words_size + total_vector_size
|
||||
|
||||
|
||||
class DocumentService:
|
||||
DEFAULT_RULES = {
|
||||
|
@ -488,7 +518,8 @@ class DocumentService:
|
|||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
|
||||
'retrieval_model') else default_retrieval_model
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
|
|
Loading…
Reference in New Issue
Block a user