diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index aabc417759..823341ec42 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -14,7 +14,7 @@ from controllers.console.workspace.error import ( InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone @@ -78,6 +78,7 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required + @enterprise_license_required @marshal_with(account_fields) def get(self): return current_user diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 9e13c7b924..5fbefd0cc5 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -35,3 +35,9 @@ class AccountNotInitializedError(BaseHTTPException): error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 + + +class EnterpriseLicenseUnauthorized(BaseHTTPException): + error_code = "unauthorized" + description = "Your license is invalid. Please contact your administrator." + code = 401 diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 9f294cb93c..55e46681b0 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -6,9 +6,9 @@ from flask import abort, request from flask_login import current_user from configs import dify_config -from controllers.console.workspace.error import AccountNotInitializedError +from controllers.console.workspace.error import AccountNotInitializedError, EnterpriseLicenseUnauthorized from models.model import DifySetup -from services.feature_service import FeatureService +from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService from .error import NotInitValidateError, NotSetupError @@ -142,3 +142,15 @@ def setup_required(view): return view(*args, **kwargs) return decorated + + +def enterprise_license_required(view): + @wraps(view) + def decorated(*args, **kwargs): + settings = FeatureService.get_system_features() + if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: + raise EnterpriseLicenseUnauthorized() + + return view(*args, **kwargs) + + return decorated diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 0fde6f82d8..d0b04628cf 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,3 +1,5 @@ +from enum import Enum + from pydantic import BaseModel, ConfigDict from configs import dify_config @@ -20,8 +22,17 @@ class LimitationModel(BaseModel): limit: int = 0 +class LicenseStatus(str, Enum): + NONE = "none" + INACTIVE = "inactive" + ACTIVE = "active" + EXPIRING = "expiring" + EXPIRED = "expired" + LOST = "lost" + + class LicenseModel(BaseModel): - status: str = "none" + status: LicenseStatus = LicenseStatus.NONE expired_at: str = "" @@ -164,4 +175,4 @@ class FeatureService: features.license.status = enterprise_info["license"]["status"] if "expired_at" in enterprise_info["license"]: - features.license.expired_at = enterprise_info["license"]["expired_at"] \ No newline at end of file + features.license.expired_at = enterprise_info["license"]["expired_at"]