Feat/api jwt (#1212)

This commit is contained in:
zxhlyh 2023-09-25 12:49:16 +08:00 committed by GitHub
parent c40ee7e629
commit 227f9fb77d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 142 additions and 339 deletions

View File

@ -50,24 +50,6 @@ S3_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Cookie configuration
COOKIE_HTTPONLY=true
COOKIE_SAMESITE=None
COOKIE_SECURE=true
# Session configuration
SESSION_PERMANENT=true
SESSION_USE_SIGNER=true
## support redis, sqlalchemy
SESSION_TYPE=redis
# session redis configuration
SESSION_REDIS_HOST=localhost
SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant
VECTOR_STORE=weaviate

View File

@ -1,8 +1,7 @@
# -*- coding:utf-8 -*-
import os
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
@ -12,12 +11,11 @@ import logging
import json
import threading
from flask import Flask, request, Response, session
import flask_login
from flask import Flask, request, Response
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager
@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
from events import event_handlers
# DO NOT REMOVE ABOVE
import core
from config import Config, CloudEditionConfig
from commands import register_commands
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
from services.account_service import TenantService
from services.account_service import AccountService
from libs.passport import PassportService
import warnings
warnings.simplefilter("ignore", ResourceWarning)
@ -85,81 +81,33 @@ def initialize_extensions(app):
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
# Flask-Login configuration
@login_manager.user_loader
def load_user(user_id):
"""Load user based on the user_id."""
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
account = db.session.query(Account).filter(Account.id == account_id).first()
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)
return account
return AccountService.load_user(user_id)
else:
return None
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
@ -216,6 +164,7 @@ if app.config['TESTING']:
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response

View File

@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
dotenv.load_dotenv()
DEFAULTS = {
'COOKIE_HTTPONLY': 'True',
'COOKIE_SECURE': 'True',
'COOKIE_SAMESITE': 'None',
'DB_USERNAME': 'postgres',
'DB_PASSWORD': '',
'DB_HOST': 'localhost',
@ -22,10 +19,6 @@ DEFAULTS = {
'REDIS_PORT': '6379',
'REDIS_DB': '0',
'REDIS_USE_SSL': 'False',
'SESSION_REDIS_HOST': 'localhost',
'SESSION_REDIS_PORT': '6379',
'SESSION_REDIS_DB': '2',
'SESSION_REDIS_USE_SSL': 'False',
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
'OAUTH_REDIRECT_INDEX_PATH': '/',
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
@ -36,9 +29,6 @@ DEFAULTS = {
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'SESSION_TYPE': 'sqlalchemy',
'SESSION_PERMANENT': 'True',
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600,
@ -115,20 +105,6 @@ class Config:
# Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY')
# cookie settings
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
# session settings, only support sqlalchemy, redis
self.SESSION_TYPE = get_env('SESSION_TYPE')
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
# redis settings
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
@ -137,14 +113,6 @@ class Config:
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# session redis settings
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
# storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')

View File

@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
import services
from controllers.console import api
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.setup import setup_required
from libs.helper import email
from libs.password import valid_password
@ -37,12 +36,12 @@ class LoginApi(Resource):
except Exception:
pass
flask_login.login_user(account, remember=args['remember_me'])
AccountService.update_last_login(account, request)
# todo: return the user info
token = AccountService.get_account_jwt_token(account)
return {'result': 'success'}
return {'result': 'success', 'data': token}
class LogoutApi(Resource):

View File

@ -2,9 +2,8 @@ import logging
from datetime import datetime
from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
from flask import request, redirect, current_app
from flask_restful import Resource
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
@ -75,12 +74,11 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.utcnow()
db.session.commit()
# login user
session.clear()
flask_login.login_user(account, remember=True)
AccountService.update_last_login(account, request)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
token = AccountService.get_account_jwt_token(account)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

View File

@ -1,7 +1,6 @@
# -*- coding:utf-8 -*-
from functools import wraps
import flask_login
from flask import request, current_app
from flask_restful import Resource, reqparse
@ -58,9 +57,6 @@ class SetupApi(Resource):
)
setup()
# Login
flask_login.login_user(account)
AccountService.update_last_login(account, request)
return {'result': 'success'}, 201

View File

@ -1,11 +1,10 @@
import os
from functools import wraps
import flask_login
from flask import current_app
from flask import g
from flask import has_request_context
from flask import request
from flask import request, session
from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized

View File

@ -1,174 +0,0 @@
import redis
from redis.connection import SSLConnection, Connection
from flask import request
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
from flask_session.sessions import total_seconds
from itsdangerous import want_bytes
from extensions.ext_database import db
sess = Session()
def init_app(app):
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
app,
db,
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
session_type = app.config.get('SESSION_TYPE')
if session_type == 'sqlalchemy':
app.session_interface = sqlalchemy_session_interface
elif session_type == 'redis':
connection_class = Connection
if app.config.get('SESSION_REDIS_USE_SSL', False):
connection_class = SSLConnection
sess_redis_client = redis.Redis()
sess_redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
'port': app.config.get('SESSION_REDIS_PORT', 6379),
'username': app.config.get('SESSION_REDIS_USERNAME', None),
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
'db': app.config.get('SESSION_REDIS_DB', 2),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)
app.extensions['session_redis'] = sess_redis_client
app.session_interface = CustomRedisSessionInterface(
sess_redis_client,
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
def __init__(
self,
app,
db,
table,
key_prefix,
use_signer=False,
permanent=True,
sequence=None,
autodelete=False,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy(app)
self.db = db
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.autodelete = autodelete
self.sequence = sequence
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
class Session(self.db.Model):
__tablename__ = table
if sequence:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, primary_key=True
)
session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)
def __init__(self, session_id, data, expiry):
self.session_id = session_id
self.data = data
self.expiry = expiry
def __repr__(self):
return f"<Session data {self.data}>"
self.sql_session_model = Session
def save_session(self, *args, **kwargs):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
return super().save_session(*args, **kwargs)
class CustomRedisSessionInterface(RedisSessionInterface):
def save_session(self, app, session, response):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
if not self.should_set_cookie(app, session):
return
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
if not session:
if session.modified:
self.redis.delete(self.key_prefix + session.sid)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Modification case. There are upsides and downsides to
# emitting a set-cookie header each request. The behavior
# is controlled by the :meth:`should_set_cookie` method
# which performs a quick check to figure out if the cookie
# should be set or not. This is controlled by the
# SESSION_REFRESH_EACH_REQUEST config flag as well as
# the permanent flag on the session itself.
# if not self.should_set_cookie(app, session):
# return
conditional_cookie_kwargs = {}
httponly = self.get_cookie_httponly(app)
secure = self.get_cookie_secure(app)
if self.has_same_site_capability:
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
expires = self.get_expiration_time(app, session)
if session.permanent:
value = self.serializer.dumps(dict(session))
if value is not None:
self.redis.setex(
name=self.key_prefix + session.sid,
value=value,
time=total_seconds(app.permanent_session_lifetime),
)
if self.use_signer:
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
else:
session_id = session.sid
response.set_cookie(
app.config["SESSION_COOKIE_NAME"],
session_id,
expires=expires,
httponly=httponly,
domain=domain,
path=path,
secure=secure,
**conditional_cookie_kwargs,
)

View File

@ -4,11 +4,12 @@ import json
import logging
import secrets
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from hashlib import sha256
from typing import Optional
from flask import session
from werkzeug.exceptions import Forbidden, Unauthorized
from flask import session, current_app
from sqlalchemy import func
from events.tenant_event import tenant_was_created
@ -19,16 +20,82 @@ from services.errors.account import AccountLoginError, CurrentPasswordIncorrectE
from libs.helper import get_remote_ip
from libs.password import compare_password, hash_password
from libs.rsa import generate_key_pair
from libs.passport import PassportService
from models.account import *
from tasks.mail_invite_member_task import send_invite_member_mail_task
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
class AccountService:
@staticmethod
def load_user(account_id: int) -> Account:
def load_user(user_id: str) -> Account:
# todo: used by flask_login
pass
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
return account
@staticmethod
def get_account_jwt_token(account):
payload = {
"user_id": account.id,
"exp": datetime.utcnow() + timedelta(days=30),
"iss": current_app.config['EDITION'],
"sub": 'Console API Passport',
}
token = PassportService().issue(payload)
return token
@staticmethod
def authenticate(email: str, password: str) -> Account:

View File

@ -49,15 +49,6 @@ services:
REDIS_USE_SSL: 'false'
# use redis db 0 for redis cache
REDIS_DB: 0
# The configurations of session, Supported values are `sqlalchemy`. `redis`
SESSION_TYPE: redis
SESSION_REDIS_HOST: redis
SESSION_REDIS_PORT: 6379
SESSION_REDIS_USERNAME: ''
SESSION_REDIS_PASSWORD: difyai123456
SESSION_REDIS_USE_SSL: 'false'
# use redis db 2 for session store
SESSION_REDIS_DB: 2
# The configurations of celery broker.
# Use redis as the broker, and redis db 1 for celery broker.
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
@ -76,10 +67,6 @@ services:
# If you want to enable cross-origin support,
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
#
# For **production** purposes, please set `SameSite=Lax, Secure=true, HttpOnly=true`.
COOKIE_HTTPONLY: 'true'
COOKIE_SAMESITE: 'Lax'
COOKIE_SECURE: 'false'
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
STORAGE_TYPE: local
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.

View File

@ -1,7 +1,9 @@
'use client'
import { SWRConfig } from 'swr'
import { useEffect, useState } from 'react'
import type { ReactNode } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
type SwrInitorProps = {
children: ReactNode
@ -9,13 +11,32 @@ type SwrInitorProps = {
const SwrInitor = ({
children,
}: SwrInitorProps) => {
return (
const router = useRouter()
const searchParams = useSearchParams()
const consoleToken = searchParams.get('console_token')
const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
const [init, setInit] = useState(false)
useEffect(() => {
if (!(consoleToken || consoleTokenFromLocalStorage))
router.replace('/signin')
if (consoleToken) {
localStorage?.setItem('console_token', consoleToken!)
router.replace('/apps', { forceOptimisticNavigation: false })
}
setInit(true)
}, [])
return init
? (
<SWRConfig value={{
shouldRetryOnError: false,
}}>
{children}
</SWRConfig>
)
: null
}
export default SwrInitor

View File

@ -8,6 +8,10 @@ import I18n from '@/context/i18n'
const Header = () => {
const { locale, setLocaleOnClient } = useContext(I18n)
if (localStorage?.getItem('console_token'))
localStorage.removeItem('console_token')
return <div className='flex items-center justify-between p-6 w-full'>
<div className={style.logo}></div>
<Select

View File

@ -89,7 +89,7 @@ const NormalForm = () => {
}
try {
setIsLoading(true)
await login({
const res = await login({
url: '/login',
body: {
email,
@ -97,7 +97,8 @@ const NormalForm = () => {
remember_me: true,
},
})
router.push('/apps')
localStorage.setItem('console_token', res.data)
router.replace('/apps')
}
finally {
setIsLoading(false)

View File

@ -179,6 +179,10 @@ const baseFetch = <T>(
}
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
}
else {
const accessToken = localStorage.getItem('console_token') || ''
options.headers.set('Authorization', `Bearer ${accessToken}`)
}
if (deleteContentType) {
options.headers.delete('Content-Type')
@ -292,7 +296,9 @@ export const upload = (options: any): Promise<any> => {
const defaultOptions = {
method: 'POST',
url: `${API_PREFIX}/files/upload`,
headers: {},
headers: {
Authorization: `Bearer ${localStorage.getItem('console_token') || ''}`,
},
data: {},
}
options = {

View File

@ -15,8 +15,8 @@ import type {
} from '@/models/app'
import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations'
export const login: Fetcher<CommonResponse, { url: string; body: Record<string, any> }> = ({ url, body }) => {
return post<CommonResponse>(url, { body })
export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => {
return post(url, { body }) as Promise<CommonResponse & { data: string }>
}
export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {