mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
Feat/dify billing (#1679)
Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: takatost <takatost@users.noreply.github.com>
This commit is contained in:
parent
d3a2c0ed34
commit
053102f433
|
@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
|
||||||
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
||||||
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
||||||
|
|
||||||
|
# Stripe configuration
|
||||||
STRIPE_API_KEY=
|
STRIPE_API_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# Billing configuration
|
||||||
|
BILLING_API_URL=http://127.0.0.1:8000/v1
|
||||||
|
BILLING_API_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_BILLING_SECRET=
|
|
@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio
|
||||||
|
|
||||||
# Import webhook controllers
|
# Import webhook controllers
|
||||||
from .webhook import stripe
|
from .webhook import stripe
|
||||||
|
|
||||||
|
from .billing import billing
|
||||||
|
|
0
api/controllers/console/billing/__init__.py
Normal file
0
api/controllers/console/billing/__init__.py
Normal file
85
api/controllers/console/billing/billing.py
Normal file
85
api/controllers/console/billing/billing.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import stripe
|
||||||
|
import os
|
||||||
|
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask import current_app, request
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from controllers.console.wraps import only_edition_cloud
|
||||||
|
from libs.login import login_required
|
||||||
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
|
class BillingInfo(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
|
||||||
|
edition = current_app.config['EDITION']
|
||||||
|
if edition != 'CLOUD':
|
||||||
|
return {"enabled": False}
|
||||||
|
|
||||||
|
return BillingService.get_info(current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@only_edition_cloud
|
||||||
|
def get(self):
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
|
||||||
|
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
class Invoices(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@only_edition_cloud
|
||||||
|
def get(self):
|
||||||
|
|
||||||
|
return BillingService.get_invoices(current_user.email)
|
||||||
|
|
||||||
|
|
||||||
|
class StripeBillingWebhook(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@only_edition_cloud
|
||||||
|
def post(self):
|
||||||
|
payload = request.data
|
||||||
|
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
||||||
|
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = stripe.Webhook.construct_event(
|
||||||
|
payload, sig_header, webhook_secret
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# Invalid payload
|
||||||
|
return 'Invalid payload', 400
|
||||||
|
except stripe.error.SignatureVerificationError as e:
|
||||||
|
# Invalid signature
|
||||||
|
return 'Invalid signature', 400
|
||||||
|
|
||||||
|
BillingService.process_event(event)
|
||||||
|
|
||||||
|
return 'success', 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(BillingInfo, '/billing/info')
|
||||||
|
api.add_resource(Subscription, '/billing/subscription')
|
||||||
|
api.add_resource(Invoices, '/billing/invoices')
|
||||||
|
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')
|
|
@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||||
|
|
||||||
|
|
55
api/services/billing_service.py
Normal file
55
api/services/billing_service.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
|
||||||
|
|
||||||
|
class BillingService:
|
||||||
|
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||||
|
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_info(cls, tenant_id: str):
|
||||||
|
params = {'tenant_id': tenant_id}
|
||||||
|
|
||||||
|
billing_info = cls._send_request('GET', '/info', params=params)
|
||||||
|
|
||||||
|
vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024
|
||||||
|
billing_info['vector_space']['size'] = int(vector_size)
|
||||||
|
|
||||||
|
return billing_info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
|
||||||
|
params = {
|
||||||
|
'plan': plan,
|
||||||
|
'interval': interval,
|
||||||
|
'prefilled_email': prefilled_email,
|
||||||
|
'user_name': user_name,
|
||||||
|
'tenant_id': tenant_id
|
||||||
|
}
|
||||||
|
return cls._send_request('GET', '/subscription', params=params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_invoices(cls, prefilled_email: str = ''):
|
||||||
|
params = {'prefilled_email': prefilled_email}
|
||||||
|
return cls._send_request('GET', '/invoices', params=params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Billing-Api-Secret-Key": cls.secret_key
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{cls.base_url}{endpoint}"
|
||||||
|
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def process_event(cls, event: dict):
|
||||||
|
json = {
|
||||||
|
"content": event,
|
||||||
|
}
|
||||||
|
return cls._send_request('POST', '/webhook/stripe', json=json)
|
|
@ -227,6 +227,36 @@ class DatasetService:
|
||||||
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
|
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
|
||||||
.order_by(db.desc(AppDatasetJoin.created_at)).all()
|
.order_by(db.desc(AppDatasetJoin.created_at)).all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_datasets_usage(tenant_id):
|
||||||
|
# get the high_quality datasets
|
||||||
|
dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
|
||||||
|
Dataset.tenant_id == tenant_id).all()
|
||||||
|
if not dataset_ids:
|
||||||
|
return 0
|
||||||
|
dataset_ids = [result[0] for result in dataset_ids]
|
||||||
|
document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
|
||||||
|
Document.tenant_id == tenant_id,
|
||||||
|
Document.completed_at.isnot(None),
|
||||||
|
Document.enabled == True,
|
||||||
|
Document.archived == False
|
||||||
|
).all()
|
||||||
|
if not document_ids:
|
||||||
|
return 0
|
||||||
|
document_ids = [result[0] for result in document_ids]
|
||||||
|
document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
|
||||||
|
DocumentSegment.tenant_id == tenant_id,
|
||||||
|
DocumentSegment.completed_at.isnot(None),
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
).all()
|
||||||
|
if not document_segments:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
|
||||||
|
total_vector_size = 1536 * 4 * len(document_segments)
|
||||||
|
|
||||||
|
return total_words_size + total_vector_size
|
||||||
|
|
||||||
|
|
||||||
class DocumentService:
|
class DocumentService:
|
||||||
DEFAULT_RULES = {
|
DEFAULT_RULES = {
|
||||||
|
@ -488,7 +518,8 @@ class DocumentService:
|
||||||
'score_threshold_enabled': False
|
'score_threshold_enabled': False
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
|
||||||
|
'retrieval_model') else default_retrieval_model
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user