feat: reconstruct verification logic

This commit is contained in:
GareArc 2024-10-03 13:21:15 -04:00
parent 21b8f26cd1
commit e5c9f821ab
7 changed files with 176 additions and 85 deletions

View File

@ -599,9 +599,9 @@ class PositionConfig(BaseSettings):
class VerificationConfig(BaseSettings): class VerificationConfig(BaseSettings):
VERIFICATION_CODE_EXPIRY: PositiveInt = Field( VERIFICATION_CODE_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in seconds for which a verification code remains valid", description="Duration in minutes for which a verification code remains valid",
default=300, default=5,
) )
VERIFICATION_CODE_LENGTH: PositiveInt = Field( VERIFICATION_CODE_LENGTH: PositiveInt = Field(
@ -609,9 +609,9 @@ class VerificationConfig(BaseSettings):
default=6, default=6,
) )
VERIFICATION_CODE_COOLDOWN: PositiveInt = Field( VERIFICATION_CODE_COOLDOWN_MINUTES: PositiveInt = Field(
description="Cooldown time in seconds between verification code generation", description="Cooldown time in minutes between verification code generation",
default=60, default=1,
) )

View File

@ -248,17 +248,20 @@ class AccountDeleteVerifyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def post(self):
account = current_user account = current_user
try: try:
code = VerificationService.generate_account_deletion_verification_code(account.email) code = VerificationService.generate_account_deletion_verification_code(account)
AccountService.send_account_delete_verification_email(account, code) AccountService.send_account_delete_verification_email(account, code)
except RateLimitExceededError: except RateLimitExceededError:
return {"result": "fail", "error": "Rate limit exceeded."}, 429 return {"result": "fail", "error": "Rate limit exceeded."}, 429
return {"result": "success"} return {"result": "success"}
class AccountDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -267,13 +270,29 @@ class AccountDeleteVerifyApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("code", type=str, required=True, location="json")
args = parser.parse_args()
if not VerificationService.verify_account_deletion_verification_code(account, args["code"]):
return {"result": "fail", "error": "Verification code is invalid."}, 400
AccountService.delete_account(account)
return {"result": "success"}
@setup_required
@login_required
@account_initialization_required
def patch(self):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("reason", type=str, required=True, location="json") parser.add_argument("reason", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not VerificationService.verify_account_deletion_verification_code(account.email, args["code"]): try:
return {"result": "fail", "error": "Verification code is invalid."}, 400 AccountService.update_deletion_reason(account, args["reason"])
except ValueError as e:
AccountService.delete_account(account, args["reason"], args["code"]) return {"result": "fail", "error": str(e)}, 400
return {"result": "success"} return {"result": "success"}
@ -288,6 +307,7 @@ api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone") api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password") api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates") api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete-verify") api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
# api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@ -8,6 +8,7 @@ import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime from datetime import datetime
from datetime import timezone as tz
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional, Union from typing import Any, Optional, Union
from zoneinfo import available_timezones from zoneinfo import available_timezones
@ -19,6 +20,9 @@ from flask import Response, current_app, stream_with_context
from flask_restful import fields from flask_restful import fields
from models.account import Account from models.account import Account
from api.configs import dify_config
from api.services.errors.account import RateLimitExceededError
def run(script): def run(script):
return subprocess.getstatusoutput("source /root/.bashrc && " + script) return subprocess.getstatusoutput("source /root/.bashrc && " + script)
@ -269,3 +273,66 @@ class RateLimiter:
redis_client.zadd(key, {current_time: current_time}) redis_client.zadd(key, {current_time: current_time})
redis_client.expire(key, self.time_window * 2) redis_client.expire(key, self.time_window * 2)
def get_current_datetime():
return datetime.now(tz.utc).replace(tzinfo=None)
class VerificationCodeManager:
@classmethod
def generate_verification_code(cls, account: Account, code_type: str) -> str:
# Check if this key is still in cooldown period
now = int(time.time())
created_at = cls._get_verification_code_created_at(code_type, account.id)
if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN_MINUTES * 60:
raise RateLimitExceededError()
if created_at is not None:
cls._revoke_verification_code(code_type, account.id)
verification_code = generate_string(dify_config.VERIFICATION_CODE_LENGTH)
cls._set_verification_code(code_type, account.id, verification_code, dify_config.VERIFICATION_CODE_EXPIRY_MINUTES)
return verification_code
@classmethod
def verify_verification_code(cls, account: Account, code_type: str, verification_code: str) -> bool:
key, _ = cls._get_key(code_type, account_id=account.id)
stored_verification_code = redis_client.get(key)
if stored_verification_code is None:
return False
return stored_verification_code == verification_code
### Helper methods ###
@classmethod
def _set_verification_code(cls, code_type: str, account_id: str, verification_code: str, expire_minutes: int) -> None:
key, time_key = cls._get_key(code_type, account_id)
now = int(time.time())
redis_client.setex(key, expire_minutes * 60, verification_code)
redis_client.setex(time_key, expire_minutes * 60, now)
@classmethod
def _get_verification_code(cls, code_type: str, account_id: str) -> Optional[str]:
key, _ = cls._get_key(code_type, account_id)
verification_code = redis_client.get(key)
return verification_code
@classmethod
def _get_verification_code_created_at(cls, code_type: str, account_id: str) -> Optional[int]:
_, time_key = cls._get_key(code_type, account_id)
created_at = redis_client.get(time_key)
return int(created_at) if created_at is not None else None
@classmethod
def _revoke_verification_code(cls, code_type: str, account_id: str) -> None:
key, time_key = cls._get_key(code_type, account_id)
redis_client.delete(key)
redis_client.delete(time_key)
@classmethod
def _get_key(cls, code_type: str, account_id: str) -> tuple[str, str]:
return f"verification:{code_type}:{account_id}", f"verification:{code_type}:{account_id}:time"

View File

@ -278,7 +278,7 @@ class AccountDeletionLog(db.Model):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
account_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
status = db.Column(Enum(AccountDeletionLogStatus), nullable=False, default=AccountDeletionLogStatus.PENDING) status = db.Column(Enum(AccountDeletionLogStatus), nullable=False, default=AccountDeletionLogStatus.PENDING)
reason = db.Column(db.Text) reason = db.Column(db.Text, nullable=True)
email = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

View File

@ -7,9 +7,11 @@ from extensions.ext_database import db
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from api.libs.helper import get_current_datetime
from api.models.account import (Account, AccountDeletionLog, from api.models.account import (Account, AccountDeletionLog,
AccountDeletionLogStatus, AccountIntegrate, AccountDeletionLogStatus, AccountIntegrate,
TenantAccountJoin, TenantAccountJoinRole) Tenant, TenantAccountJoin,
TenantAccountJoinRole)
from api.models.api_based_extension import APIBasedExtension from api.models.api_based_extension import APIBasedExtension
from api.models.dataset import (AppDatasetJoin, Dataset, DatasetPermission, from api.models.dataset import (AppDatasetJoin, Dataset, DatasetPermission,
Document, DocumentSegment) Document, DocumentSegment)
@ -18,11 +20,15 @@ from api.models.model import (ApiToken, App, AppAnnotationSetting,
DatasetRetrieverResource, EndUser, InstalledApp, DatasetRetrieverResource, EndUser, InstalledApp,
Message, MessageAgentThought, MessageAnnotation, Message, MessageAgentThought, MessageAnnotation,
MessageChain, MessageFeedback, MessageFile, MessageChain, MessageFeedback, MessageFile,
RecommendedApp) RecommendedApp, Site, Tag, TagBinding)
from api.models.provider import (LoadBalancingModelConfig, Provider, from api.models.provider import (LoadBalancingModelConfig, Provider,
ProviderModel, ProviderModelSetting) ProviderModel, ProviderModelSetting,
TenantDefaultModel,
TenantPreferredModelProvider)
from api.models.source import (DataSourceApiKeyAuthBinding, from api.models.source import (DataSourceApiKeyAuthBinding,
DataSourceOauthBinding) DataSourceOauthBinding)
from api.models.tools import (ApiToolProvider, BuiltinToolProvider,
ToolConversationVariables)
from api.models.web import PinnedConversation, SavedMessage from api.models.web import PinnedConversation, SavedMessage
from api.tasks.mail_account_deletion_done_task import \ from api.tasks.mail_account_deletion_done_task import \
send_deletion_success_task send_deletion_success_task
@ -82,6 +88,9 @@ def _delete_app(app: App, account_id):
# saved_messages # saved_messages
db.session.query(SavedMessage).filter(SavedMessage.app_id == app.id).delete() db.session.query(SavedMessage).filter(SavedMessage.app_id == app.id).delete()
# sites
db.session.query(Site).filter(Site.app_id == app.id).delete()
db.session.delete(app) db.session.delete(app)
@ -93,6 +102,8 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
""" """
tenant_id, account_id = tenant_account_join.tenant_id, tenant_account_join.account_id tenant_id, account_id = tenant_account_join.tenant_id, tenant_account_join.account_id
member_ids = db.session.query(TenantAccountJoin.account_id).filter(TenantAccountJoin.tenant_id == tenant_id).all()
# api_based_extensions # api_based_extensions
db.session.query(APIBasedExtension).filter(APIBasedExtension.tenant_id == tenant_id).delete() db.session.query(APIBasedExtension).filter(APIBasedExtension.tenant_id == tenant_id).delete()
@ -105,17 +116,19 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
db.session.query(DatasetPermission).filter(DatasetPermission.tenant_id == tenant_id).delete() db.session.query(DatasetPermission).filter(DatasetPermission.tenant_id == tenant_id).delete()
# datasets # datasets
dataset_ids = db.session.query(Dataset.id).filter(Dataset.tenant_id == tenant_id).all()
db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).delete() db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).delete()
# documents
document_ids = db.session.query(Document.id).filter(Document.tenant_id == tenant_id).all()
db.session.query(Document).filter(Document.tenant_id == tenant_id).delete()
# data_source_api_key_auth_bindings # data_source_api_key_auth_bindings
db.session.query(DataSourceApiKeyAuthBinding).filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id).delete() db.session.query(DataSourceApiKeyAuthBinding).filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id).delete()
# data_source_oauth_bindings # data_source_oauth_bindings
db.session.query(DataSourceOauthBinding).filter(DataSourceOauthBinding.tenant_id == tenant_id).delete() db.session.query(DataSourceOauthBinding).filter(DataSourceOauthBinding.tenant_id == tenant_id).delete()
# documents
db.session.query(Document).filter(Document.tenant_id == tenant_id).delete()
# document_segments # document_segments
db.session.query(DocumentSegment).filter(DocumentSegment.tenant_id == tenant_id).delete() db.session.query(DocumentSegment).filter(DocumentSegment.tenant_id == tenant_id).delete()
@ -128,14 +141,38 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
# provder_model_settings # provder_model_settings
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).delete() db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).delete()
# skip provider_orders # skip provider_orders (TODO: confirm this)
# providers # providers
db.session.query(Provider).filter(Provider.tenant_id == tenant_id).delete() db.session.query(Provider).filter(Provider.tenant_id == tenant_id).delete()
# tag_bindings
db.session.query(TagBinding).filter(TagBinding.tenant_id == tenant_id).delete()
# tags
db.session.query(Tag).filter(Tag.tenant_id == tenant_id).delete()
# tenant_default_models
db.session.query(TenantDefaultModel).filter(TenantDefaultModel.tenant_id == tenant_id).delete()
# tenant_preferred_model_providers
db.session.query(TenantPreferredModelProvider).filter(TenantPreferredModelProvider.tenant_id == tenant_id).delete()
# tool_api_providers
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).delete()
# tool_built_in_providers
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).delete()
# tool_conversation_variables
db.session.query(ToolConversationVariables).filter(ToolConversationVariables.tenant_id == tenant_id).delete()
# Delete all tenant_account_joins of this tenant # Delete all tenant_account_joins of this tenant
db.session.query(TenantAccountJoin).filter(TenantAccountJoin.tenant_id == tenant_id).delete() db.session.query(TenantAccountJoin).filter(TenantAccountJoin.tenant_id == tenant_id).delete()
# Delete tenant
db.session.query(Tenant).filter(Tenant.id == tenant_id).delete()
def _delete_tenant_as_non_owner(tenant_account_join: TenantAccountJoin): def _delete_tenant_as_non_owner(tenant_account_join: TenantAccountJoin):
"""Actual deletion of tenant as non-owner. Related tables will also be deleted. """Actual deletion of tenant as non-owner. Related tables will also be deleted.
@ -182,10 +219,15 @@ def _delete_user(log: AccountDeletionLog, account: Account) -> bool:
# delete account # delete account
db.session.delete(account) db.session.delete(account)
# update log status
log.status = AccountDeletionLogStatus.COMPLETED
log.updated_at = get_current_datetime()
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.session.rollback() db.session.rollback()
logger.exception(click.style(f"Failed to delete account {log.account_id}, error: {e}", fg="red")) logger.exception(click.style(f"Failed to delete account {log.account_id}, error: {e}", fg="red"))
log.status = AccountDeletionLogStatus.FAILED log.status = AccountDeletionLogStatus.FAILED
log.updated_at = get_current_datetime()
success = False success = False
finally: finally:
db.session.commit() db.session.commit()
@ -231,6 +273,7 @@ def delete_account_task():
if not account: if not account:
logger.exception(click.style(f"Account {log.account_id} not found.", fg="red")) logger.exception(click.style(f"Account {log.account_id} not found.", fg="red"))
log.status = AccountDeletionLogStatus.FAILED log.status = AccountDeletionLogStatus.FAILED
log.updated_at = get_current_datetime()
db.session.commit() db.session.commit()
continue continue

View File

@ -10,7 +10,7 @@ from configs import dify_config
from constants.languages import language_timezone_mapping, languages from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.helper import RateLimiter, TokenManager from libs.helper import RateLimiter, TokenManager, get_current_datetime
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
@ -155,7 +155,18 @@ class AccountService:
return account return account
@staticmethod @staticmethod
def delete_account(account: Account, reason: str) -> None: def update_deletion_reason(account: Account, reason: str) -> None:
"""Update deletion log reason"""
account_deletion_log = AccountDeletionLog.query.filter_by(account_id=account.id).first()
if not account_deletion_log:
raise Exception("Account deletion log not found.")
account_deletion_log.reason = reason
account_deletion_log.updated_at = get_current_datetime()
db.session.commit()
@staticmethod
def delete_account(account: Account) -> None:
"""Delete account. Actual deletion is done by the background scheduler.""" """Delete account. Actual deletion is done by the background scheduler."""
logging.info(f"Start deletion of account {account.id}.") logging.info(f"Start deletion of account {account.id}.")
@ -163,7 +174,6 @@ class AccountService:
account_deletion_log = AccountDeletionLog( account_deletion_log = AccountDeletionLog(
account_id=account.id, account_id=account.id,
status=AccountDeletionLogStatus.PENDING, status=AccountDeletionLogStatus.PENDING,
reason=reason
) )
db.session.add(account_deletion_log) db.session.add(account_deletion_log)
db.session.commit() db.session.commit()

View File

@ -1,69 +1,20 @@
from api.libs.helper import VerificationCodeManager
import time from api.models.account import Account
from typing import Optional
from api.configs import dify_config
from api.extensions.ext_redis import redis_client
from api.libs.helper import generate_string, generate_text_hash
from api.services.errors.account import RateLimitExceededError
class VerificationService: class VerificationService:
@classmethod @classmethod
def generate_account_deletion_verification_code(cls, email: str) -> str: def generate_account_deletion_verification_code(cls, account: Account) -> str:
return cls._generate_verification_code( return VerificationCodeManager.generate_verification_code(
email=email, account=account,
prefix="account_deletion", code_type="account_deletion",
expire=dify_config.VERIFICATION_CODE_EXPIRY,
code_length=dify_config.VERIFICATION_CODE_LENGTH
) )
@classmethod @classmethod
def verify_account_deletion_verification_code(cls, email: str, verification_code: str) -> bool: def verify_account_deletion_verification_code(cls, account: Account, verification_code: str) -> bool:
return cls._verify_verification_code( return VerificationCodeManager.verify_verification_code(
email=email, account=account,
prefix="account_deletion", code_type="account_deletion",
verification_code=verification_code verification_code=verification_code,
) )
### Helper methods ###
@classmethod
def _generate_verification_code(cls, key_name: str, prefix: str, expire: int = 300, code_length: int = 6) -> str:
hashed_key = generate_text_hash(key_name)
key, time_key = cls._get_key(f"{prefix}:{hashed_key}")
now = int(time.time())
# Check if there is already a verification code for this key within 1 minute
created_at = redis_client.get(time_key)
if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN:
raise RateLimitExceededError()
created_at = now
verification_code = generate_string(code_length)
redis_client.setex(key, expire, verification_code)
redis_client.setex(time_key, expire, created_at)
return verification_code
@classmethod
def _get_verification_code(cls, prefix: str, key_name: str) -> Optional[str]:
hashed_key = generate_text_hash(key_name)
key, _ = cls._get_key(f"{prefix}:{hashed_key}")
verification_code = redis_client.get(key)
return verification_code
@classmethod
def _verify_verification_code(cls, key_name: str, prefix: str, verification_code: str) -> bool:
hashed_key = generate_text_hash(key_name)
key, _ = cls._get_key(f"{prefix}:{hashed_key}")
stored_verification_code = redis_client.get(key)
if stored_verification_code is None:
return False
return stored_verification_code == verification_code
@classmethod
def _get_key(cls, key_name: str) -> str:
return f"verification:{key_name}", f"verification:{key_name}:time"