mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Feat/implement-refresh-tokens (#9233)
This commit is contained in:
parent
dbfbc56de7
commit
f73751843f
|
@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
|
||||||
# The time in seconds after the signature is rejected
|
# The time in seconds after the signature is rejected
|
||||||
FILES_ACCESS_TIMEOUT=300
|
FILES_ACCESS_TIMEOUT=300
|
||||||
|
|
||||||
|
# Access token expiration time in minutes
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||||
|
|
||||||
# celery configuration
|
# celery configuration
|
||||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||||
|
|
||||||
|
|
|
@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
|
||||||
decoded = PassportService().verify(auth_token)
|
decoded = PassportService().verify(auth_token)
|
||||||
user_id = decoded.get("user_id")
|
user_id = decoded.get("user_id")
|
||||||
|
|
||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
if logged_in_account:
|
if logged_in_account:
|
||||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||||
return logged_in_account
|
return logged_in_account
|
||||||
|
|
|
@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OAuthConfig(BaseSettings):
|
class AuthConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Configuration for OAuth authentication
|
Configuration for authentication and OAuth
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OAUTH_REDIRECT_PATH: str = Field(
|
OAUTH_REDIRECT_PATH: str = Field(
|
||||||
|
@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
GITHUB_CLIENT_ID: Optional[str] = Field(
|
GITHUB_CLIENT_ID: Optional[str] = Field(
|
||||||
description="GitHub OAuth client secret",
|
description="GitHub OAuth client ID",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
|
||||||
|
description="Expiration time for access tokens in minutes",
|
||||||
|
default=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModerationConfig(BaseSettings):
|
class ModerationConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
|
@ -607,6 +612,7 @@ class PositionConfig(BaseSettings):
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
AppExecutionConfig,
|
AppExecutionConfig,
|
||||||
|
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||||
BillingConfig,
|
BillingConfig,
|
||||||
CodeExecutionSandboxConfig,
|
CodeExecutionSandboxConfig,
|
||||||
DataSetConfig,
|
DataSetConfig,
|
||||||
|
@ -621,14 +627,13 @@ class FeatureConfig(
|
||||||
MailConfig,
|
MailConfig,
|
||||||
ModelLoadBalanceConfig,
|
ModelLoadBalanceConfig,
|
||||||
ModerationConfig,
|
ModerationConfig,
|
||||||
OAuthConfig,
|
PositionConfig,
|
||||||
RagEtlConfig,
|
RagEtlConfig,
|
||||||
SecurityConfig,
|
SecurityConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
UpdateConfig,
|
UpdateConfig,
|
||||||
WorkflowConfig,
|
WorkflowConfig,
|
||||||
WorkspaceConfig,
|
WorkspaceConfig,
|
||||||
PositionConfig,
|
|
||||||
# hosted services config
|
# hosted services config
|
||||||
HostedServiceConfig,
|
HostedServiceConfig,
|
||||||
CeleryBeatConfig,
|
CeleryBeatConfig,
|
||||||
|
|
|
@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from libs.helper import email, get_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
|
@ -40,17 +40,16 @@ class LoginApi(Resource):
|
||||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
|
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token_pair.model_dump()}
|
||||||
|
|
||||||
|
|
||||||
class LogoutApi(Resource):
|
class LogoutApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def get(self):
|
||||||
account = cast(Account, flask_login.current_user)
|
account = cast(Account, flask_login.current_user)
|
||||||
token = request.headers.get("Authorization", "").split(" ")[1]
|
AccountService.logout(account=account)
|
||||||
AccountService.logout(account=account, token=token)
|
|
||||||
flask_login.logout_user()
|
flask_login.logout_user()
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenApi(Resource):
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||||
|
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||||
|
except Exception as e:
|
||||||
|
return {"result": "fail", "data": str(e)}, 401
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(LoginApi, "/login")
|
api.add_resource(LoginApi, "/login")
|
||||||
api.add_resource(LogoutApi, "/logout")
|
api.add_resource(LogoutApi, "/logout")
|
||||||
|
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||||
|
|
|
@ -9,7 +9,7 @@ from flask_restful import Resource
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import get_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
from models.account import Account, AccountStatus
|
from models.account import Account, AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
|
@ -81,9 +81,14 @@ class OAuthCallback(Resource):
|
||||||
|
|
||||||
TenantService.create_owner_tenant_if_not_exist(account)
|
TenantService.create_owner_tenant_if_not_exist(account)
|
||||||
|
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token_pair = AccountService.login(
|
||||||
|
account=account,
|
||||||
|
ip_address=extract_remote_ip(request),
|
||||||
|
)
|
||||||
|
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
|
return redirect(
|
||||||
|
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||||
|
|
|
@ -4,7 +4,7 @@ from flask import request
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import StrLen, email, get_remote_ip
|
from libs.helper import StrLen, email, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
@ -46,7 +46,7 @@ class SetupApi(Resource):
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
RegisterService.setup(
|
RegisterService.setup(
|
||||||
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
|
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
|
@ -162,7 +162,7 @@ def generate_string(n):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_remote_ip(request) -> str:
|
def extract_remote_ip(request) -> str:
|
||||||
if request.headers.get("CF-Connecting-IP"):
|
if request.headers.get("CF-Connecting-IP"):
|
||||||
return request.headers.get("Cf-Connecting-Ip")
|
return request.headers.get("Cf-Connecting-Ip")
|
||||||
elif request.headers.getlist("X-Forwarded-For"):
|
elif request.headers.getlist("X-Forwarded-For"):
|
||||||
|
|
|
@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
@ -49,9 +50,39 @@ from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||||
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPair(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||||
|
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||||
|
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
|
||||||
|
|
||||||
|
|
||||||
class AccountService:
|
class AccountService:
|
||||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
|
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_refresh_token_key(refresh_token: str) -> str:
|
||||||
|
return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_account_refresh_token_key(account_id: str) -> str:
|
||||||
|
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||||
|
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
|
||||||
|
redis_client.setex(
|
||||||
|
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||||
|
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||||
|
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_user(user_id: str) -> None | Account:
|
def load_user(user_id: str) -> None | Account:
|
||||||
account = Account.query.filter_by(id=user_id).first()
|
account = Account.query.filter_by(id=user_id).first()
|
||||||
|
@ -61,9 +92,7 @@ class AccountService:
|
||||||
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||||
raise Unauthorized("Account is banned or closed.")
|
raise Unauthorized("Account is banned or closed.")
|
||||||
|
|
||||||
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
|
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||||
account_id=account.id, current=True
|
|
||||||
).first()
|
|
||||||
if current_tenant:
|
if current_tenant:
|
||||||
account.current_tenant_id = current_tenant.tenant_id
|
account.current_tenant_id = current_tenant.tenant_id
|
||||||
else:
|
else:
|
||||||
|
@ -84,10 +113,12 @@ class AccountService:
|
||||||
return account
|
return account
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
|
def get_account_jwt_token(account: Account) -> str:
|
||||||
|
exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
exp = int(exp_dt.timestamp())
|
||||||
payload = {
|
payload = {
|
||||||
"user_id": account.id,
|
"user_id": account.id,
|
||||||
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
|
"exp": exp,
|
||||||
"iss": dify_config.EDITION,
|
"iss": dify_config.EDITION,
|
||||||
"sub": "Console API Passport",
|
"sub": "Console API Passport",
|
||||||
}
|
}
|
||||||
|
@ -213,7 +244,7 @@ class AccountService:
|
||||||
return account
|
return account
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_last_login(account: Account, *, ip_address: str) -> None:
|
def update_login_info(account: Account, *, ip_address: str) -> None:
|
||||||
"""Update last login time and ip"""
|
"""Update last login time and ip"""
|
||||||
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
account.last_login_ip = ip_address
|
account.last_login_ip = ip_address
|
||||||
|
@ -221,22 +252,45 @@ class AccountService:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def login(account: Account, *, ip_address: Optional[str] = None):
|
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
|
||||||
if ip_address:
|
if ip_address:
|
||||||
AccountService.update_last_login(account, ip_address=ip_address)
|
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||||
exp = timedelta(days=30)
|
|
||||||
token = AccountService.get_account_jwt_token(account, exp=exp)
|
access_token = AccountService.get_account_jwt_token(account=account)
|
||||||
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
|
refresh_token = _generate_refresh_token()
|
||||||
return token
|
|
||||||
|
AccountService._store_refresh_token(refresh_token, account.id)
|
||||||
|
|
||||||
|
return TokenPair(access_token=access_token, refresh_token=refresh_token)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logout(*, account: Account, token: str):
|
def logout(*, account: Account) -> None:
|
||||||
redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
|
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
||||||
|
if refresh_token:
|
||||||
|
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_logged_in_account(*, account_id: str, token: str):
|
def refresh_token(refresh_token: str) -> TokenPair:
|
||||||
if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
|
# Verify the refresh token
|
||||||
return None
|
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
||||||
|
if not account_id:
|
||||||
|
raise ValueError("Invalid refresh token")
|
||||||
|
|
||||||
|
account = AccountService.load_user(account_id.decode("utf-8"))
|
||||||
|
if not account:
|
||||||
|
raise ValueError("Invalid account")
|
||||||
|
|
||||||
|
# Generate new access token and refresh token
|
||||||
|
new_access_token = AccountService.get_account_jwt_token(account)
|
||||||
|
new_refresh_token = _generate_refresh_token()
|
||||||
|
|
||||||
|
AccountService._delete_refresh_token(refresh_token, account.id)
|
||||||
|
AccountService._store_refresh_token(new_refresh_token, account.id)
|
||||||
|
|
||||||
|
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_logged_in_account(*, account_id: str):
|
||||||
return AccountService.load_user(account_id)
|
return AccountService.load_user(account_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -258,10 +312,6 @@ class AccountService:
|
||||||
return TokenManager.get_token_data(token, "reset_password")
|
return TokenManager.get_token_data(token, "reset_password")
|
||||||
|
|
||||||
|
|
||||||
def _get_login_cache_key(*, account_id: str, token: str):
|
|
||||||
return f"account_login:{account_id}:{token}"
|
|
||||||
|
|
||||||
|
|
||||||
class TenantService:
|
class TenantService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_tenant(name: str) -> Tenant:
|
def create_tenant(name: str) -> Tenant:
|
||||||
|
@ -698,3 +748,8 @@ class RegisterService:
|
||||||
|
|
||||||
invitation = json.loads(data)
|
invitation = json.loads(data)
|
||||||
return invitation
|
return invitation
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_refresh_token(length: int = 64):
|
||||||
|
token = secrets.token_hex(length)
|
||||||
|
return token
|
||||||
|
|
|
@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
|
||||||
# The default value is 300 seconds.
|
# The default value is 300 seconds.
|
||||||
FILES_ACCESS_TIMEOUT=300
|
FILES_ACCESS_TIMEOUT=300
|
||||||
|
|
||||||
|
# Access token expiration time in minutes
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||||
|
|
||||||
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
|
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
|
||||||
APP_MAX_ACTIVE_REQUESTS=0
|
APP_MAX_ACTIVE_REQUESTS=0
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env
|
||||||
REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
|
REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
|
||||||
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
|
REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
|
||||||
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
|
REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
|
||||||
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
||||||
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
|
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
|
||||||
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
|
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user