feat: add api for account deletion

This commit is contained in:
GareArc 2024-09-29 13:20:44 -04:00
parent 4669eb24be
commit 21b8f26cd1
10 changed files with 549 additions and 37 deletions

View File

@ -1,9 +1,9 @@
from typing import Annotated, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
from pydantic import (AliasChoices, Field, HttpUrl, NegativeInt,
NonNegativeInt, PositiveInt, computed_field)
from pydantic_settings import BaseSettings
class SecurityConfig(BaseSettings):
@ -598,6 +598,23 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class VerificationConfig(BaseSettings):
    """
    Settings for one-time verification codes: lifetime, length and resend cooldown
    """

    # How long an issued code stays valid.
    VERIFICATION_CODE_EXPIRY: PositiveInt = Field(
        default=300,
        description="Duration in seconds for which a verification code remains valid",
    )

    # Number of characters in a generated code.
    VERIFICATION_CODE_LENGTH: PositiveInt = Field(
        default=6,
        description="Length of the verification code",
    )

    # Minimum spacing between two code generations for the same key.
    VERIFICATION_CODE_COOLDOWN: PositiveInt = Field(
        default=60,
        description="Cooldown time in seconds between verification code generation",
    )
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -623,6 +640,7 @@ class FeatureConfig(
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
VerificationConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,

View File

@ -1,28 +1,29 @@
import datetime
import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidInvitationCodeError,
RepeatPasswordNotMatchError,
)
from controllers.console.workspace.error import (AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidInvitationCodeError,
RepeatPasswordNotMatchError)
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
from services.errors.account import \
CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
from services.errors.account import RateLimitExceededError
from api.services.verification_service import VerificationService
class AccountInitApi(Resource):
@ -242,6 +243,41 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data}
class AccountDeleteVerifyApi(Resource):
    """Two-step account deletion for the signed-in user.

    GET generates a verification code and emails it to the account address.
    POST checks the submitted code and, if it matches, records a deletion
    request through ``AccountService.delete_account``.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Generate a deletion verification code and email it to the current user.

        Returns:
            ``{"result": "success"}`` on success, or a ``429`` response when the
            verification service reports that the rate limit was exceeded.
        """
        account = current_user

        try:
            code = VerificationService.generate_account_deletion_verification_code(account.email)
            AccountService.send_account_delete_verification_email(account, code)
        except RateLimitExceededError:
            return {"result": "fail", "error": "Rate limit exceeded."}, 429

        return {"result": "success"}

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Verify the submitted code and request deletion of the current account.

        JSON body:
            code (str, required): Verification code received by email.
            reason (str, required): Free-text reason for leaving.

        Returns:
            ``{"result": "success"}`` on success, or a ``400`` response when the
            code does not verify.
        """
        account = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("code", type=str, required=True, location="json")
        parser.add_argument("reason", type=str, required=True, location="json")
        args = parser.parse_args()

        if not VerificationService.verify_account_deletion_verification_code(account.email, args["code"]):
            return {"result": "fail", "error": "Verification code is invalid."}, 400

        # delete_account takes (account, reason); the code has already been
        # checked above and must not be passed as a third positional argument.
        AccountService.delete_account(account, args["reason"])

        return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
@ -252,5 +288,6 @@ api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete-verify")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@ -53,6 +53,7 @@ def init_app(app: Flask) -> Celery:
imports = [
"schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task",
"schedule.delete_account_task",
]
day = app.config.get("CELERY_BEAT_SCHEDULER_TIME")
beat_schedule = {
@ -64,6 +65,10 @@ def init_app(app: Flask) -> Celery:
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
},
"delete_account_task": {
"task": "schedule.delete_account_task.delete_account_task",
"schedule": timedelta(hours='*/1'), # pull every 1 hour
},
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

View File

@ -12,12 +12,11 @@ from hashlib import sha256
from typing import Any, Optional, Union
from zoneinfo import available_timezones
from flask import Response, current_app, stream_with_context
from flask_restful import fields
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_redis import redis_client
from flask import Response, current_app, stream_with_context
from flask_restful import fields
from models.account import Account

View File

@ -1,9 +1,9 @@
import enum
import json
from flask_login import UserMixin
from extensions.ext_database import db
from flask_login import UserMixin
from sqlalchemy import Enum
from .types import StringUUID
@ -259,3 +259,27 @@ class InvitationCode(db.Model):
used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class AccountDeletionLogStatus(str, enum.Enum):
    """Lifecycle states of an account deletion request."""

    # Request recorded, not yet processed.
    PENDING = "pending"
    # Processing was attempted and did not succeed.
    FAILED = "failed"
    # Processing finished successfully.
    COMPLETED = "completed"
class AccountDeletionLog(db.Model):
    """One row per account deletion request, with its processing status."""

    __tablename__ = "account_deletion_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="account_deletion_log_pkey"),
        # Indexed for lookups by account and by status.
        db.Index("account_deletion_logs_account_id_idx", "account_id"),
        db.Index("account_deletion_logs_status_idx", "status"),
    )

    id = db.Column(
        StringUUID,
        server_default=db.text("uuid_generate_v4()"),
    )
    # Account the request refers to.
    account_id = db.Column(StringUUID, nullable=False)
    status = db.Column(
        Enum(AccountDeletionLogStatus),
        nullable=False,
        default=AccountDeletionLogStatus.PENDING,
    )
    # Optional free-text reason.
    reason = db.Column(db.Text)
    email = db.Column(db.String(255), nullable=False)
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.text("CURRENT_TIMESTAMP(0)"),
    )
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.text("CURRENT_TIMESTAMP(0)"),
    )

View File

@ -0,0 +1,244 @@
import logging
import time
import app
import click
from extensions.ext_database import db
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError
from api.models.account import (Account, AccountDeletionLog,
AccountDeletionLogStatus, AccountIntegrate,
TenantAccountJoin, TenantAccountJoinRole)
from api.models.api_based_extension import APIBasedExtension
from api.models.dataset import (AppDatasetJoin, Dataset, DatasetPermission,
Document, DocumentSegment)
from api.models.model import (ApiToken, App, AppAnnotationSetting,
AppModelConfig, Conversation,
DatasetRetrieverResource, EndUser, InstalledApp,
Message, MessageAgentThought, MessageAnnotation,
MessageChain, MessageFeedback, MessageFile,
RecommendedApp)
from api.models.provider import (LoadBalancingModelConfig, Provider,
ProviderModel, ProviderModelSetting)
from api.models.source import (DataSourceApiKeyAuthBinding,
DataSourceOauthBinding)
from api.models.web import PinnedConversation, SavedMessage
from api.tasks.mail_account_deletion_done_task import \
send_deletion_success_task
logger = logging.getLogger(__name__)
def _delete_app(app: App, account_id):
    """Delete an app and the rows in related tables that reference it.

    The deletes are issued on ``db.session``; this function does not commit.

    Args:
        app: App object to delete.
        account_id: Account ID. Not used by the deletion itself; kept so
            existing callers keep working.
    """
    app_id = app.id

    # Tables keyed directly by app_id, removed before the messages section.
    for model in (
        ApiToken,
        AppAnnotationSetting,
        AppDatasetJoin,
        AppModelConfig,
        Conversation,
        EndUser,
        InstalledApp,
    ):
        db.session.query(model).filter(model.app_id == app_id).delete()

    ### messages ###
    # query(Message.id).all() yields one-element rows; flatten them to plain
    # IDs so that in_() receives scalar values rather than row objects.
    message_ids = [row[0] for row in db.session.query(Message.id).filter(Message.app_id == app_id).all()]

    if message_ids:
        # Tables keyed by message_id.
        for model in (MessageAgentThought, MessageChain, MessageFile, MessageFeedback):
            db.session.query(model).filter(model.message_id.in_(message_ids)).delete(synchronize_session=False)

    # message_annotations
    db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).delete()

    # messages themselves, once their dependent rows are gone
    db.session.query(Message).filter(Message.app_id == app_id).delete()

    # Remaining tables keyed by app_id.
    for model in (PinnedConversation, RecommendedApp, SavedMessage):
        db.session.query(model).filter(model.app_id == app_id).delete()

    db.session.delete(app)
def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
    """Remove the data of a tenant that the account owns.

    Deletes the tenant's API-based extensions, every app (via ``_delete_app``),
    the tenant-scoped dataset / data-source / provider tables, and finally all
    membership rows of the tenant. ``provider_orders`` is intentionally skipped.

    Args:
        tenant_account_join (TenantAccountJoin): Membership row linking the
            account to the tenant; supplies the tenant and account IDs.
    """
    session = db.session
    tenant_id = tenant_account_join.tenant_id
    account_id = tenant_account_join.account_id

    # api_based_extensions
    session.query(APIBasedExtension).filter(APIBasedExtension.tenant_id == tenant_id).delete()

    # every app of this tenant, together with its related tables
    for tenant_app in session.query(App).filter(App.tenant_id == tenant_id).all():
        _delete_app(tenant_app, account_id)

    # Tables filtered by tenant_id, in deletion order. Membership rows go last.
    tenant_scoped_models = (
        DatasetPermission,
        Dataset,
        DataSourceApiKeyAuthBinding,
        DataSourceOauthBinding,
        Document,
        DocumentSegment,
        LoadBalancingModelConfig,
        ProviderModel,
        ProviderModelSetting,
        Provider,
        TenantAccountJoin,
    )
    for model in tenant_scoped_models:
        session.query(model).filter(model.tenant_id == tenant_id).delete()
def _delete_tenant_as_non_owner(tenant_account_join: TenantAccountJoin):
    """Detach the account from a tenant it does not own.

    Only the membership row is deleted.

    Args:
        tenant_account_join (TenantAccountJoin): Membership row to remove.
    """
    membership = tenant_account_join
    db.session.delete(membership)
def _delete_user(log: AccountDeletionLog, account: Account) -> bool:
    """Delete an account and its data, recording the outcome on ``log``.

    All deletes are committed together with ``log.status = COMPLETED``. If the
    database raises, the work is rolled back and the log is committed as
    ``FAILED`` instead. The session is left open so the caller can keep using
    objects it has loaded.

    Args:
        log (AccountDeletionLog): Account deletion log object
        account (Account): Account row to delete

    Returns:
        bool: True if deletion is successful, False otherwise
    """
    account_id = log.account_id

    try:
        # No explicit db.session.begin(): the session starts a transaction on
        # first use, and calling begin() on a session that is already in one
        # raises instead of nesting.

        # find all tenants this account belongs to
        tenant_account_joins = (
            db.session.query(TenantAccountJoin).filter(TenantAccountJoin.account_id == account_id).all()
        )

        # process all tenants
        for tenant_account_join in tenant_account_joins:
            if tenant_account_join.role == TenantAccountJoinRole.OWNER.value:
                _delete_tenant_as_owner(tenant_account_join)
            else:
                _delete_tenant_as_non_owner(tenant_account_join)

        # account_integrates
        db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account_id).delete()

        # dataset_retriever_resources
        db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.created_by == account_id).delete()

        # delete account
        db.session.delete(account)

        # Mark the request as done in the same transaction as the deletes so
        # the two cannot disagree.
        log.status = AccountDeletionLogStatus.COMPLETED
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        logger.exception(click.style(f"Failed to delete account {account_id}, error: {e}", fg="red"))
        # Persist the failure on its own, after the rollback discarded the deletes.
        log.status = AccountDeletionLogStatus.FAILED
        db.session.commit()
        return False

    return True
@app.celery.task(queue="dataset")
def delete_account_task():
    """Process every pending or previously failed account deletion request.

    The IDs of the requests are snapshotted once, newest first, and then worked
    through in batches. Paging by OFFSET over the live query would skip rows
    as their status changes during the run. For each request the account is
    deleted via ``_delete_user`` and, on success, a confirmation email task is
    queued.
    """
    logger.info(click.style("Start delete account task.", fg="green"))
    start_at = time.perf_counter()

    # Query available deletion tasks from database (failed ones are retried)
    log_ids = [
        row[0]
        for row in db.session.query(AccountDeletionLog.id)
        .filter(
            or_(
                AccountDeletionLog.status == AccountDeletionLogStatus.PENDING,
                AccountDeletionLog.status == AccountDeletionLogStatus.FAILED,
            )
        )
        .order_by(AccountDeletionLog.created_at.desc())
        .all()
    ]
    queue_size = len(log_ids)
    logger.info(f"Found {queue_size} delete account tasks in queue.")

    # execute deletion in batch
    batch_size = 50
    n_batches = (queue_size + batch_size - 1) // batch_size
    for i in range(n_batches):
        logger.info(f"Start delete account task batch {i+1}/{n_batches}.")

        for log_id in log_ids[i * batch_size : (i + 1) * batch_size]:
            # Load each log fresh by ID so it is valid in the current session
            # state regardless of what earlier iterations committed.
            log = db.session.query(AccountDeletionLog).filter(AccountDeletionLog.id == log_id).first()
            if log is None:
                continue

            account = db.session.query(Account).filter(Account.id == log.account_id).first()
            if not account:
                # Not inside an except block, so there is no traceback to attach.
                logger.error(click.style(f"Account {log.account_id} not found.", fg="red"))
                log.status = AccountDeletionLogStatus.FAILED
                db.session.commit()
                continue

            # Read these before the row is deleted; afterwards the instance
            # can no longer be loaded.
            language, email = account.interface_language, account.email

            # Delete and notify
            if _delete_user(log, account):
                send_deletion_success_task.delay(language, email)

    end_at = time.perf_counter()
    logger.info(
        click.style(
            "Delete account tasks of size {} completed with latency: {}".format(queue_size, end_at - start_at),
            fg="green",
        )
    )
    db.session.remove()

View File

@ -6,9 +6,6 @@ from datetime import datetime, timedelta, timezone
from hashlib import sha256
from typing import Any, Optional
from sqlalchemy import func
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
@ -19,23 +16,24 @@ from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
from models.account import *
from models.model import DifySetup
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotLinkTenantError,
AccountRegisterError,
CannotOperateSelfError,
CurrentPasswordIncorrectError,
InvalidActionError,
LinkAccountIntegrateError,
MemberNotInTenantError,
NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError,
TenantNotFoundError,
)
from services.errors.account import (AccountAlreadyInTenantError,
AccountLoginError,
AccountNotLinkTenantError,
AccountRegisterError,
CannotOperateSelfError,
CurrentPasswordIncorrectError,
InvalidActionError,
LinkAccountIntegrateError,
MemberNotInTenantError, NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError,
TenantNotFoundError)
from sqlalchemy import func
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
from werkzeug.exceptions import Unauthorized
from api.tasks import mail_account_deletion_verify_task
class AccountService:
@ -156,6 +154,20 @@ class AccountService:
db.session.commit()
return account
@staticmethod
def delete_account(account: Account, reason: str) -> None:
    """Record a deletion request for ``account``.

    Only a pending ``AccountDeletionLog`` row is written here. Actual deletion
    is done by the background scheduler.

    Args:
        account: Account to be deleted.
        reason: Reason supplied for the deletion.
    """
    logging.info(f"Start deletion of account {account.id}.")

    # add deletion log, set status to pending
    account_deletion_log = AccountDeletionLog(
        account_id=account.id,
        # email is a NOT NULL column on the log; omitting it makes the commit fail.
        email=account.email,
        status=AccountDeletionLogStatus.PENDING,
        reason=reason,
    )
    db.session.add(account_deletion_log)
    db.session.commit()
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate"""
@ -246,6 +258,11 @@ class AccountService:
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def send_account_delete_verification_email(cls, account: Account, code: str):
    """Queue the email that carries the account deletion verification code.

    Args:
        account: Recipient account; its interface language and email are used.
        code: Verification code to include in the email.
    """
    language, email = account.interface_language, account.email
    # `mail_account_deletion_verify_task` is imported as a module, which has no
    # `.delay`; the Celery task is the function defined inside it.
    # NOTE(review): assumes the task module exposes
    # `send_account_deletion_verification_code` — confirm against tasks/.
    mail_account_deletion_verify_task.send_account_deletion_verification_code.delay(
        language=language, to=email, code=code
    )
def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"

View File

@ -0,0 +1,69 @@
import time
from typing import Optional
from api.configs import dify_config
from api.extensions.ext_redis import redis_client
from api.libs.helper import generate_string, generate_text_hash
from api.services.errors.account import RateLimitExceededError
class VerificationService:
    """Generate and check short-lived verification codes stored in Redis.

    For each (prefix, key) pair two Redis entries are kept: the code itself
    and the time it was created. The creation time enforces a cooldown
    between generations.
    """

    @classmethod
    def generate_account_deletion_verification_code(cls, email: str) -> str:
        """Create an account deletion code for ``email``.

        Raises:
            RateLimitExceededError: A code was generated for this email less
                than ``VERIFICATION_CODE_COOLDOWN`` seconds ago.
        """
        return cls._generate_verification_code(
            key_name=email,
            prefix="account_deletion",
            expire=dify_config.VERIFICATION_CODE_EXPIRY,
            code_length=dify_config.VERIFICATION_CODE_LENGTH,
        )

    @classmethod
    def verify_account_deletion_verification_code(cls, email: str, verification_code: str) -> bool:
        """Return True if ``verification_code`` matches the code stored for ``email``."""
        return cls._verify_verification_code(
            key_name=email,
            prefix="account_deletion",
            verification_code=verification_code,
        )

    ### Helper methods ###

    @classmethod
    def _generate_verification_code(cls, key_name: str, prefix: str, expire: int = 300, code_length: int = 6) -> str:
        """Generate, store and return a new code for ``prefix``/``key_name``.

        Args:
            key_name: Value identifying the subject (hashed before use in the key).
            prefix: Namespace separating different kinds of codes.
            expire: Lifetime in seconds of both Redis entries.
            code_length: Length passed to ``generate_string``.

        Raises:
            RateLimitExceededError: The previous code is younger than the cooldown.
        """
        hashed_key = generate_text_hash(key_name)
        key, time_key = cls._get_key(f"{prefix}:{hashed_key}")
        now = int(time.time())

        # Check if there is already a verification code for this key within the cooldown.
        # The stored timestamp comes back from Redis as bytes or text, never as
        # an int, so it must be converted before the subtraction.
        created_at = cls._decode(redis_client.get(time_key))
        if created_at is not None and now - int(created_at) < dify_config.VERIFICATION_CODE_COOLDOWN:
            raise RateLimitExceededError()

        verification_code = generate_string(code_length)
        redis_client.setex(key, expire, verification_code)
        redis_client.setex(time_key, expire, now)
        return verification_code

    @classmethod
    def _get_verification_code(cls, prefix: str, key_name: str) -> Optional[str]:
        """Return the stored code for ``prefix``/``key_name`` as text, or None if absent."""
        hashed_key = generate_text_hash(key_name)
        key, _ = cls._get_key(f"{prefix}:{hashed_key}")
        return cls._decode(redis_client.get(key))

    @classmethod
    def _verify_verification_code(cls, key_name: str, prefix: str, verification_code: str) -> bool:
        """Compare ``verification_code`` with the stored one; False when none is stored."""
        stored_verification_code = cls._get_verification_code(prefix, key_name)
        if stored_verification_code is None:
            return False
        return stored_verification_code == verification_code

    @classmethod
    def _get_key(cls, key_name: str) -> tuple[str, str]:
        """Return the ``(code_key, created_at_key)`` Redis key pair for ``key_name``."""
        return f"verification:{key_name}", f"verification:{key_name}:time"

    @staticmethod
    def _decode(value):
        """Return ``value`` as ``str`` when it is bytes; pass anything else through.

        The Redis client may return bytes depending on its configuration, and
        bytes never compare equal to ``str``.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

View File

@ -0,0 +1,48 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from flask import render_template
@shared_task(queue="mail")
def send_deletion_success_task(language, to):
    """Email the user a notice that their account has been deleted.

    Does nothing when the mail extension is not initialised. Failures while
    rendering or sending are logged and not re-raised.

    Args:
        language (str): Language code; "zh-Hans" selects the Chinese template,
            any other value the English one.
        to (str): Recipient email address
    """
    if not mail.is_inited():
        return

    logging.info(click.style(f"Start send account deletion success email to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        # Pick template and subject first, then render and send once.
        if language == "zh-Hans":
            template_name = "delete_account_mail_template_zh-CN.html"
            subject = "Dify 账户删除成功"
        else:
            template_name = "delete_account_mail_template_en-US.html"
            subject = "Dify Account Deleted"

        # TODO: Add more template variables
        html_content = render_template(template_name, to=to)
        mail.send(to=to, subject=subject, html=html_content)

        end_at = time.perf_counter()
        success_message = "Send account deletion success email to {} succeeded: latency: {}".format(
            to, end_at - start_at
        )
        logging.info(click.style(success_message, fg="green"))
    except Exception:
        logging.exception("Send account deletion success email to {} failed".format(to))

View File

@ -0,0 +1,51 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from flask import render_template
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_account_deletion_verification_code(language, to, code):
    """Email the account deletion verification code to the user.

    Does nothing when the mail extension is not initialised. Failures while
    rendering or sending are logged and not re-raised.

    Args:
        language (str): Language code; "zh-Hans" selects the Chinese template,
            any other value the English one.
        to (str): Recipient email address
        code (str): Verification code
    """
    if not mail.is_inited():
        return

    logging.info(click.style(f"Start send account deletion verification code email to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        # Pick template and subject first, then render and send once.
        if language == "zh-Hans":
            template_name = "delete_account_verification_code_mail_template_zh-CN.html"
            subject = "Dify 删除账户验证码"
        else:
            template_name = "delete_account_verification_code_mail_template_en-US.html"
            subject = "Dify Account Deletion Verification Code"

        html_content = render_template(template_name, to=to, code=code)
        mail.send(to=to, subject=subject, html=html_content)

        end_at = time.perf_counter()
        success_message = "Send account deletion verification code email to {} succeeded: latency: {}".format(
            to, end_at - start_at
        )
        logging.info(click.style(success_message, fg="green"))
    except Exception:
        logging.exception("Send account deletion verification code email to {} failed".format(to))