diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 7b65898956..2d5f95aef0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -599,9 +599,9 @@ class PositionConfig(BaseSettings): class VerificationConfig(BaseSettings): - VERIFICATION_CODE_EXPIRY: PositiveInt = Field( - description="Duration in seconds for which a verification code remains valid", - default=300, + VERIFICATION_CODE_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a verification code remains valid", + default=5, ) VERIFICATION_CODE_LENGTH: PositiveInt = Field( @@ -609,9 +609,9 @@ class VerificationConfig(BaseSettings): default=6, ) - VERIFICATION_CODE_COOLDOWN: PositiveInt = Field( - description="Cooldown time in seconds between verification code generation", - default=60, + VERIFICATION_CODE_COOLDOWN_MINUTES: PositiveInt = Field( + description="Cooldown time in minutes between verification code generation", + default=1, ) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 74799c8cf4..9c70c7dfe9 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -248,17 +248,20 @@ class AccountDeleteVerifyApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + def post(self): account = current_user try: - code = VerificationService.generate_account_deletion_verification_code(account.email) + code = VerificationService.generate_account_deletion_verification_code(account) AccountService.send_account_delete_verification_email(account, code) except RateLimitExceededError: return {"result": "fail", "error": "Rate limit exceeded."}, 429 return {"result": "success"} + +class AccountDeleteApi(Resource): + @setup_required @login_required @account_initialization_required @@ -267,13 +270,29 @@ class AccountDeleteVerifyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("code", type=str, required=True, location="json") + args = parser.parse_args() + + if not VerificationService.verify_account_deletion_verification_code(account, args["code"]): + return {"result": "fail", "error": "Verification code is invalid."}, 400 + + AccountService.delete_account(account) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def patch(self): + account = current_user + + parser = reqparse.RequestParser() parser.add_argument("reason", type=str, required=True, location="json") args = parser.parse_args() - if not VerificationService.verify_account_deletion_verification_code(account.email, args["code"]): - return {"result": "fail", "error": "Verification code is invalid."}, 400 - - AccountService.delete_account(account, args["reason"], args["code"]) + try: + AccountService.update_deletion_reason(account, args["reason"]) + except ValueError as e: + return {"result": "fail", "error": str(e)}, 400 return {"result": "success"} @@ -288,6 +307,7 @@ api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") api.add_resource(AccountTimezoneApi, "/account/timezone") api.add_resource(AccountPasswordApi, "/account/password") api.add_resource(AccountIntegrateApi, "/account/integrates") -api.add_resource(AccountDeleteVerifyApi, "/account/delete-verify") +api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") +api.add_resource(AccountDeleteApi, "/account/delete") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/libs/helper.py b/api/libs/helper.py index 8273c8eac3..986540e4df 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -8,6 +8,7 @@ import time import uuid from collections.abc import Generator from datetime import datetime +from datetime import timezone as tz from hashlib import sha256 from typing import Any, Optional, Union from zoneinfo import available_timezones @@ -19,6 +20,9 @@ from flask import Response, current_app, stream_with_context from flask_restful import fields from models.account import Account +from api.configs import dify_config +from api.services.errors.account import RateLimitExceededError + def run(script): return subprocess.getstatusoutput("source /root/.bashrc && " + script) @@ -269,3 +273,66 @@ class RateLimiter: redis_client.zadd(key, {current_time: current_time}) redis_client.expire(key, self.time_window * 2) + + +def get_current_datetime(): + return datetime.now(tz.utc).replace(tzinfo=None) + + +class VerificationCodeManager: + @classmethod + def generate_verification_code(cls, account: Account, code_type: str) -> str: + # Check if this key is still in cooldown period + now = int(time.time()) + created_at = cls._get_verification_code_created_at(code_type, account.id) + if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN_MINUTES * 60: + raise RateLimitExceededError() + if created_at is not None: + cls._revoke_verification_code(code_type, account.id) + + verification_code = generate_string(dify_config.VERIFICATION_CODE_LENGTH) + cls._set_verification_code(code_type, account.id, verification_code, dify_config.VERIFICATION_CODE_EXPIRY_MINUTES) + + return verification_code + + @classmethod + def verify_verification_code(cls, account: Account, code_type: str, verification_code: str) -> bool: + key, _ = cls._get_key(code_type, account_id=account.id) + stored_verification_code = redis_client.get(key) + + if stored_verification_code is None: + return False + return stored_verification_code == verification_code + + ### Helper methods ### + @classmethod + def _set_verification_code(cls, code_type: str, account_id: str, verification_code: str, expire_minutes: int) -> None: + key, time_key = cls._get_key(code_type, account_id) + now = int(time.time()) + + redis_client.setex(key, expire_minutes * 60, verification_code) + redis_client.setex(time_key, expire_minutes * 60, now) + + @classmethod + def _get_verification_code(cls, code_type: str, account_id: str) -> Optional[str]: + key, _ = cls._get_key(code_type, account_id) + verification_code = redis_client.get(key) + + return verification_code + + @classmethod + def _get_verification_code_created_at(cls, code_type: str, account_id: str) -> Optional[int]: + _, time_key = cls._get_key(code_type, account_id) + created_at = redis_client.get(time_key) + + return int(created_at) if created_at is not None else None + + @classmethod + def _revoke_verification_code(cls, code_type: str, account_id: str) -> None: + key, time_key = cls._get_key(code_type, account_id) + redis_client.delete(key) + redis_client.delete(time_key) + + @classmethod + def _get_key(cls, code_type: str, account_id: str) -> tuple[str, str]: + return f"verification:{code_type}:{account_id}", f"verification:{code_type}:{account_id}:time" diff --git a/api/models/account.py b/api/models/account.py index 9f0ff7bbe4..9f5ce418f4 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -278,7 +278,7 @@ class AccountDeletionLog(db.Model): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) status = db.Column(Enum(AccountDeletionLogStatus), nullable=False, default=AccountDeletionLogStatus.PENDING) - reason = db.Column(db.Text) + reason = db.Column(db.Text, nullable=True) email = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/schedule/delete_account_task.py b/api/schedule/delete_account_task.py index 1e4ccc2913..51d0b72984 100644 --- a/api/schedule/delete_account_task.py +++ b/api/schedule/delete_account_task.py @@ -7,9 +7,11 @@ from extensions.ext_database import db from sqlalchemy import or_ from sqlalchemy.exc import SQLAlchemyError +from api.libs.helper import get_current_datetime from api.models.account import (Account, AccountDeletionLog, AccountDeletionLogStatus, AccountIntegrate, - TenantAccountJoin, TenantAccountJoinRole) + Tenant, TenantAccountJoin, + TenantAccountJoinRole) from api.models.api_based_extension import APIBasedExtension from api.models.dataset import (AppDatasetJoin, Dataset, DatasetPermission, Document, DocumentSegment) @@ -18,11 +20,15 @@ from api.models.model import (ApiToken, App, AppAnnotationSetting, DatasetRetrieverResource, EndUser, InstalledApp, Message, MessageAgentThought, MessageAnnotation, MessageChain, MessageFeedback, MessageFile, - RecommendedApp) + RecommendedApp, Site, Tag, TagBinding) from api.models.provider import (LoadBalancingModelConfig, Provider, - ProviderModel, ProviderModelSetting) + ProviderModel, ProviderModelSetting, + TenantDefaultModel, + TenantPreferredModelProvider) from api.models.source import (DataSourceApiKeyAuthBinding, DataSourceOauthBinding) +from api.models.tools import (ApiToolProvider, BuiltinToolProvider, + ToolConversationVariables) from api.models.web import PinnedConversation, SavedMessage from api.tasks.mail_account_deletion_done_task import \ send_deletion_success_task @@ -82,6 +88,9 @@ def _delete_app(app: App, account_id): # saved_messages db.session.query(SavedMessage).filter(SavedMessage.app_id == app.id).delete() + # sites + db.session.query(Site).filter(Site.app_id == app.id).delete() + db.session.delete(app) @@ -93,6 +102,8 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin): """ tenant_id, account_id = tenant_account_join.tenant_id, tenant_account_join.account_id + member_ids = db.session.query(TenantAccountJoin.account_id).filter(TenantAccountJoin.tenant_id == tenant_id).all() + # api_based_extensions db.session.query(APIBasedExtension).filter(APIBasedExtension.tenant_id == tenant_id).delete() @@ -105,17 +116,19 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin): db.session.query(DatasetPermission).filter(DatasetPermission.tenant_id == tenant_id).delete() # datasets + dataset_ids = db.session.query(Dataset.id).filter(Dataset.tenant_id == tenant_id).all() db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).delete() + # documents + document_ids = db.session.query(Document.id).filter(Document.tenant_id == tenant_id).all() + db.session.query(Document).filter(Document.tenant_id == tenant_id).delete() + # data_source_api_key_auth_bindings db.session.query(DataSourceApiKeyAuthBinding).filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id).delete() # data_source_oauth_bindings db.session.query(DataSourceOauthBinding).filter(DataSourceOauthBinding.tenant_id == tenant_id).delete() - # documents - db.session.query(Document).filter(Document.tenant_id == tenant_id).delete() - # document_segments db.session.query(DocumentSegment).filter(DocumentSegment.tenant_id == tenant_id).delete() @@ -128,14 +141,38 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin): # provder_model_settings db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).delete() - # skip provider_orders + # skip provider_orders (TODO: confirm this) # providers db.session.query(Provider).filter(Provider.tenant_id == tenant_id).delete() + # tag_bindings + db.session.query(TagBinding).filter(TagBinding.tenant_id == tenant_id).delete() + + # tags + db.session.query(Tag).filter(Tag.tenant_id == tenant_id).delete() + + # tenant_default_models + db.session.query(TenantDefaultModel).filter(TenantDefaultModel.tenant_id == tenant_id).delete() + + # tenant_preferred_model_providers + db.session.query(TenantPreferredModelProvider).filter(TenantPreferredModelProvider.tenant_id == tenant_id).delete() + + # tool_api_providers + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).delete() + + # tool_built_in_providers + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).delete() + + # tool_conversation_variables + db.session.query(ToolConversationVariables).filter(ToolConversationVariables.tenant_id == tenant_id).delete() + # Delete all tenant_account_joins of this tenant db.session.query(TenantAccountJoin).filter(TenantAccountJoin.tenant_id == tenant_id).delete() + # Delete tenant + db.session.query(Tenant).filter(Tenant.id == tenant_id).delete() + def _delete_tenant_as_non_owner(tenant_account_join: TenantAccountJoin): """Actual deletion of tenant as non-owner. Related tables will also be deleted. @@ -182,10 +219,15 @@ def _delete_user(log: AccountDeletionLog, account: Account) -> bool: # delete account db.session.delete(account) + # update log status + log.status = AccountDeletionLogStatus.COMPLETED + log.updated_at = get_current_datetime() + except SQLAlchemyError as e: db.session.rollback() logger.exception(click.style(f"Failed to delete account {log.account_id}, error: {e}", fg="red")) log.status = AccountDeletionLogStatus.FAILED + log.updated_at = get_current_datetime() success = False finally: db.session.commit() @@ -231,6 +273,7 @@ def delete_account_task(): if not account: logger.exception(click.style(f"Account {log.account_id} not found.", fg="red")) log.status = AccountDeletionLogStatus.FAILED + log.updated_at = get_current_datetime() db.session.commit() continue diff --git a/api/services/account_service.py b/api/services/account_service.py index f59f3156b2..8daa7b6291 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -10,7 +10,7 @@ from configs import dify_config from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client -from libs.helper import RateLimiter, TokenManager +from libs.helper import RateLimiter, TokenManager, get_current_datetime from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair @@ -155,7 +155,18 @@ class AccountService: return account @staticmethod - def delete_account(account: Account, reason: str) -> None: + def update_deletion_reason(account: Account, reason: str) -> None: + """Update deletion log reason""" + account_deletion_log = AccountDeletionLog.query.filter_by(account_id=account.id).first() + if not account_deletion_log: + raise Exception("Account deletion log not found.") + + account_deletion_log.reason = reason + account_deletion_log.updated_at = get_current_datetime() + db.session.commit() + + @staticmethod + def delete_account(account: Account) -> None: """Delete account. Actual deletion is done by the background scheduler.""" logging.info(f"Start deletion of account {account.id}.") @@ -163,7 +174,6 @@ class AccountService: account_deletion_log = AccountDeletionLog( account_id=account.id, status=AccountDeletionLogStatus.PENDING, - reason=reason ) db.session.add(account_deletion_log) db.session.commit() diff --git a/api/services/verification_service.py b/api/services/verification_service.py index 027c9a7eff..1188b6260b 100644 --- a/api/services/verification_service.py +++ b/api/services/verification_service.py @@ -1,69 +1,20 @@ - -import time -from typing import Optional - -from api.configs import dify_config -from api.extensions.ext_redis import redis_client -from api.libs.helper import generate_string, generate_text_hash -from api.services.errors.account import RateLimitExceededError +from api.libs.helper import VerificationCodeManager +from api.models.account import Account class VerificationService: + @classmethod - def generate_account_deletion_verification_code(cls, email: str) -> str: - return cls._generate_verification_code( - email=email, - prefix="account_deletion", - expire=dify_config.VERIFICATION_CODE_EXPIRY, - code_length=dify_config.VERIFICATION_CODE_LENGTH + def generate_account_deletion_verification_code(cls, account: Account) -> str: + return VerificationCodeManager.generate_verification_code( + account=account, + code_type="account_deletion", ) @classmethod - def verify_account_deletion_verification_code(cls, email: str, verification_code: str) -> bool: - return cls._verify_verification_code( - email=email, - prefix="account_deletion", - verification_code=verification_code + def verify_account_deletion_verification_code(cls, account: Account, verification_code: str) -> bool: + return VerificationCodeManager.verify_verification_code( + account=account, + code_type="account_deletion", + verification_code=verification_code, ) - - ### Helper methods ### - - @classmethod - def _generate_verification_code(cls, key_name: str, prefix: str, expire: int = 300, code_length: int = 6) -> str: - hashed_key = generate_text_hash(key_name) - key, time_key = cls._get_key(f"{prefix}:{hashed_key}") - now = int(time.time()) - - # Check if there is already a verification code for this key within 1 minute - created_at = redis_client.get(time_key) - if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN: - raise RateLimitExceededError() - - created_at = now - verification_code = generate_string(code_length) - - redis_client.setex(key, expire, verification_code) - redis_client.setex(time_key, expire, created_at) - return verification_code - - @classmethod - def _get_verification_code(cls, prefix: str, key_name: str) -> Optional[str]: - hashed_key = generate_text_hash(key_name) - key, _ = cls._get_key(f"{prefix}:{hashed_key}") - verification_code = redis_client.get(key) - - return verification_code - - @classmethod - def _verify_verification_code(cls, key_name: str, prefix: str, verification_code: str) -> bool: - hashed_key = generate_text_hash(key_name) - key, _ = cls._get_key(f"{prefix}:{hashed_key}") - stored_verification_code = redis_client.get(key) - - if stored_verification_code is None: - return False - return stored_verification_code == verification_code - - @classmethod - def _get_key(cls, key_name: str) -> str: - return f"verification:{key_name}", f"verification:{key_name}:time"