Feat/implement-refresh-tokens (#9233)

This commit is contained in:
-LAN- 2024-10-12 23:46:30 +08:00 committed by GitHub
parent dbfbc56de7
commit f73751843f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 123 additions and 38 deletions

View File

@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected # The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300 FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# celery configuration # celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

View File

@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token) decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id") user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
if logged_in_account: if logged_in_account:
contexts.tenant_id.set(logged_in_account.current_tenant_id) contexts.tenant_id.set(logged_in_account.current_tenant_id)
return logged_in_account return logged_in_account

View File

@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
) )
class OAuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """
Configuration for OAuth authentication Configuration for authentication and OAuth
""" """
OAUTH_REDIRECT_PATH: str = Field( OAUTH_REDIRECT_PATH: str = Field(
@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
) )
GITHUB_CLIENT_ID: Optional[str] = Field( GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret", description="GitHub OAuth client ID",
default=None, default=None,
) )
@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
default=None, default=None,
) )
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
description="Expiration time for access tokens in minutes",
default=60,
)
class ModerationConfig(BaseSettings): class ModerationConfig(BaseSettings):
""" """
@ -607,6 +612,7 @@ class PositionConfig(BaseSettings):
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig, BillingConfig,
CodeExecutionSandboxConfig, CodeExecutionSandboxConfig,
DataSetConfig, DataSetConfig,
@ -621,14 +627,13 @@ class FeatureConfig(
MailConfig, MailConfig,
ModelLoadBalanceConfig, ModelLoadBalanceConfig,
ModerationConfig, ModerationConfig,
OAuthConfig, PositionConfig,
RagEtlConfig, RagEtlConfig,
SecurityConfig, SecurityConfig,
ToolConfig, ToolConfig,
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,
WorkspaceConfig, WorkspaceConfig,
PositionConfig,
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,

View File

@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from libs.helper import email, get_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models.account import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
@ -40,17 +40,16 @@ class LoginApi(Resource):
"data": "workspace not found, please contact system admin to invite you to join in a workspace", "data": "workspace not found, please contact system admin to invite you to join in a workspace",
} }
token = AccountService.login(account, ip_address=get_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token} return {"result": "success", "data": token_pair.model_dump()}
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1] AccountService.logout(account=account)
AccountService.logout(account=account, token=token)
flask_login.logout_user() flask_login.logout_user()
return {"result": "success"} return {"result": "success"}
@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
return {"result": "success"} return {"result": "success"}
class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()
try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e:
return {"result": "fail", "data": str(e)}, 401
api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout") api.add_resource(LogoutApi, "/logout")
api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@ -9,7 +9,7 @@ from flask_restful import Resource
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import get_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
@ -81,9 +81,14 @@ class OAuthCallback(Resource):
TenantService.create_owner_tenant_if_not_exist(account) TenantService.create_owner_tenant_if_not_exist(account)
token = AccountService.login(account, ip_address=get_remote_ip(request)) token_pair = AccountService.login(
account=account,
ip_address=extract_remote_ip(request),
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup from models.model import DifySetup
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@ -46,7 +46,7 @@ class SetupApi(Resource):
# setup # setup
RegisterService.setup( RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
) )
return {"result": "success"}, 201 return {"result": "success"}, 201

View File

@ -162,7 +162,7 @@ def generate_string(n):
return result return result
def get_remote_ip(request) -> str: def extract_remote_ip(request) -> str:
if request.headers.get("CF-Connecting-IP"): if request.headers.get("CF-Connecting-IP"):
return request.headers.get("Cf-Connecting-Ip") return request.headers.get("Cf-Connecting-Ip")
elif request.headers.getlist("X-Forwarded-For"): elif request.headers.getlist("X-Forwarded-For"):

View File

@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
@ -49,9 +50,39 @@ from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task
class TokenPair(BaseModel):
access_token: str
refresh_token: str
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
class AccountService: class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"
@staticmethod
def _get_account_refresh_token_key(account_id: str) -> str:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)
@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@staticmethod @staticmethod
def load_user(user_id: str) -> None | Account: def load_user(user_id: str) -> None | Account:
account = Account.query.filter_by(id=user_id).first() account = Account.query.filter_by(id=user_id).first()
@ -61,9 +92,7 @@ class AccountService:
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise Unauthorized("Account is banned or closed.") raise Unauthorized("Account is banned or closed.")
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
account_id=account.id, current=True
).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.current_tenant_id = current_tenant.tenant_id
else: else:
@ -84,10 +113,12 @@ class AccountService:
return account return account
@staticmethod @staticmethod
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
payload = { payload = {
"user_id": account.id, "user_id": account.id,
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "exp": exp,
"iss": dify_config.EDITION, "iss": dify_config.EDITION,
"sub": "Console API Passport", "sub": "Console API Passport",
} }
@ -213,7 +244,7 @@ class AccountService:
return account return account
@staticmethod @staticmethod
def update_last_login(account: Account, *, ip_address: str) -> None: def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip""" """Update last login time and ip"""
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
account.last_login_ip = ip_address account.last_login_ip = ip_address
@ -221,22 +252,45 @@ class AccountService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def login(account: Account, *, ip_address: Optional[str] = None): def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
if ip_address: if ip_address:
AccountService.update_last_login(account, ip_address=ip_address) AccountService.update_login_info(account=account, ip_address=ip_address)
exp = timedelta(days=30)
token = AccountService.get_account_jwt_token(account, exp=exp) access_token = AccountService.get_account_jwt_token(account=account)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) refresh_token = _generate_refresh_token()
return token
AccountService._store_refresh_token(refresh_token, account.id)
return TokenPair(access_token=access_token, refresh_token=refresh_token)
@staticmethod @staticmethod
def logout(*, account: Account, token: str): def logout(*, account: Account) -> None:
redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@staticmethod @staticmethod
def load_logged_in_account(*, account_id: str, token: str): def refresh_token(refresh_token: str) -> TokenPair:
if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): # Verify the refresh token
return None account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id:
raise ValueError("Invalid refresh token")
account = AccountService.load_user(account_id.decode("utf-8"))
if not account:
raise ValueError("Invalid account")
# Generate new access token and refresh token
new_access_token = AccountService.get_account_jwt_token(account)
new_refresh_token = _generate_refresh_token()
AccountService._delete_refresh_token(refresh_token, account.id)
AccountService._store_refresh_token(new_refresh_token, account.id)
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
@staticmethod
def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id) return AccountService.load_user(account_id)
@classmethod @classmethod
@ -258,10 +312,6 @@ class AccountService:
return TokenManager.get_token_data(token, "reset_password") return TokenManager.get_token_data(token, "reset_password")
def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"
class TenantService: class TenantService:
@staticmethod @staticmethod
def create_tenant(name: str) -> Tenant: def create_tenant(name: str) -> Tenant:
@ -698,3 +748,8 @@ class RegisterService:
invitation = json.loads(data) invitation = json.loads(data)
return invitation return invitation
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
return token

View File

@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
# The default value is 300 seconds. # The default value is 300 seconds.
FILES_ACCESS_TIMEOUT=300 FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0

View File

@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env
REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false} BROKER_USE_SSL: ${BROKER_USE_SSL:-false}