mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Feature/use jwt in web (#533)
Co-authored-by: crazywoola <li.zheng@dentsplysirona.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
parent
57de19a5ca
commit
d49ac1e4ac
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -109,6 +109,7 @@ venv/
|
|||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
|
|
|
@ -155,7 +155,7 @@ def register_blueprints(app):
|
|||
resources={
|
||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization'],
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
|
|
|
@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
|||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import completion, app, conversation, message, site, saved_message, audio
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||
|
|
64
api/controllers/web/passport.py
Normal file
64
api/controllers/web/passport.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from controllers.web import api
|
||||
from flask_restful import Resource
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Unauthorized, NotFound
|
||||
from models.model import Site, EndUser, App
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
def get(self):
|
||||
app_id = request.headers.get('X-App-Code')
|
||||
if app_id is None:
|
||||
raise Unauthorized('X-App-Code header is missing.')
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == app_id,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
if not site:
|
||||
raise NotFound()
|
||||
# get app from db and check if it is normal and enable_site
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=generate_session_id(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
payload = {
|
||||
"iss": site.app_id,
|
||||
'sub': 'Web API Passport',
|
||||
'app_id': site.app_id,
|
||||
'end_user_id': end_user.id,
|
||||
}
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
'access_token': tk,
|
||||
}
|
||||
|
||||
api.add_resource(PassportResource, '/passport')
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
while True:
|
||||
session_id = str(uuid.uuid4())
|
||||
existing_count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
if existing_count == 0:
|
||||
return session_id
|
|
@ -1,45 +1,27 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from flask import request, session
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Site, EndUser
|
||||
from models.model import App, EndUser
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
def validate_token(view=None):
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
site = validate_and_get_site()
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
if app_model.status != 'normal':
|
||||
raise NotFound()
|
||||
|
||||
if not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = create_or_update_end_user_for_session(app_model)
|
||||
app_model, end_user = decode_jwt_token()
|
||||
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_site():
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
def decode_jwt_token():
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized('Authorization header is missing.')
|
||||
|
@ -47,64 +29,20 @@ def validate_and_get_site():
|
|||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == auth_token,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not site:
|
||||
decoded = PassportService().verify(tk)
|
||||
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
def create_or_update_end_user_for_session(app_model):
|
||||
"""
|
||||
Create or update session terminal based on session ID.
|
||||
"""
|
||||
if 'session_id' not in session:
|
||||
session['session_id'] = generate_session_id()
|
||||
|
||||
session_id = session.get('session_id')
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.type == 'browser'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=session_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
count = 1
|
||||
session_id = ''
|
||||
while count != 0:
|
||||
session_id = str(uuid.uuid4())
|
||||
count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
|
||||
return session_id
|
||||
|
||||
return app_model, end_user
|
||||
|
||||
class WebApiResource(Resource):
|
||||
method_decorators = [validate_token]
|
||||
method_decorators = [validate_jwt_token]
|
||||
|
|
20
api/libs/passport.py
Normal file
20
api/libs/passport.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
import jwt
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from flask import current_app
|
||||
class PassportService:
|
||||
def __init__(self):
|
||||
self.sk = current_app.config.get('SECRET_KEY')
|
||||
|
||||
def issue(self, payload):
|
||||
return jwt.encode(payload, self.sk, algorithm='HS256')
|
||||
|
||||
def verify(self, token):
|
||||
try:
|
||||
return jwt.decode(token, self.sk, algorithms=['HS256'])
|
||||
except jwt.exceptions.InvalidSignatureError:
|
||||
raise Unauthorized('Invalid token signature.')
|
||||
except jwt.exceptions.DecodeError:
|
||||
raise Unauthorized('Invalid token.')
|
||||
except jwt.exceptions.ExpiredSignatureError:
|
||||
raise Unauthorized('Token has expired.')
|
|
@ -33,3 +33,4 @@ openpyxl==3.1.2
|
|||
chardet~=5.1.0
|
||||
docx2txt==0.8
|
||||
pypdfium2==4.16.0
|
||||
pyjwt~=2.6.0
|
|
@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
|
|||
import produce from 'immer'
|
||||
import { useBoolean, useGetState } from 'ahooks'
|
||||
import AppUnavailable from '../../base/app-unavailable'
|
||||
import { checkOrSetAccessToken } from '../utils'
|
||||
import useConversation from './hooks/use-conversation'
|
||||
import s from './style.module.css'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import Sidebar from '@/app/components/share/chat/sidebar'
|
||||
import ConfigSence from '@/app/components/share/chat/config-scence'
|
||||
import Header from '@/app/components/share/header'
|
||||
import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
|
||||
import {
|
||||
delConversation,
|
||||
fetchAppInfo,
|
||||
fetchAppParams,
|
||||
fetchChatList,
|
||||
fetchConversations,
|
||||
fetchSuggestedQuestions,
|
||||
pinConversation,
|
||||
sendChatMessage,
|
||||
stopChatMessageResponding,
|
||||
unpinConversation,
|
||||
updateFeedback,
|
||||
} from '@/service/share'
|
||||
import type { ConversationItem, SiteInfo } from '@/models/share'
|
||||
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
|
||||
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
|
||||
|
@ -296,7 +309,9 @@ const Main: FC<IMainProps> = ({
|
|||
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
|
||||
}
|
||||
|
||||
const fetchInitData = () => {
|
||||
const fetchInitData = async () => {
|
||||
await checkOrSetAccessToken()
|
||||
|
||||
return Promise.all([isInstalledApp
|
||||
? {
|
||||
app_id: installedAppInfo?.id,
|
||||
|
|
|
@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
|
|||
import { XMarkIcon } from '@heroicons/react/24/outline'
|
||||
import TabHeader from '../../base/tab-header'
|
||||
import Button from '../../base/button'
|
||||
import { checkOrSetAccessToken } from '../utils'
|
||||
import s from './style.module.css'
|
||||
import RunBatch from './run-batch'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
|
@ -76,9 +77,6 @@ const TextGeneration: FC<IMainProps> = ({
|
|||
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
|
||||
setSavedMessages(res.data)
|
||||
}
|
||||
useEffect(() => {
|
||||
fetchSavedMessage()
|
||||
}, [])
|
||||
const handleSaveMessage = async (messageId: string) => {
|
||||
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
|
||||
notify({ type: 'success', message: t('common.api.saved') })
|
||||
|
@ -256,7 +254,9 @@ const TextGeneration: FC<IMainProps> = ({
|
|||
setAllTaskList(newAllTaskList)
|
||||
}
|
||||
|
||||
const fetchInitData = () => {
|
||||
const fetchInitData = async () => {
|
||||
await checkOrSetAccessToken()
|
||||
|
||||
return Promise.all([isInstalledApp
|
||||
? {
|
||||
app_id: installedAppInfo?.id,
|
||||
|
@ -267,7 +267,7 @@ const TextGeneration: FC<IMainProps> = ({
|
|||
},
|
||||
plan: 'basic',
|
||||
}
|
||||
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
|
||||
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
|
|
18
web/app/components/share/utils.ts
Normal file
18
web/app/components/share/utils.ts
Normal file
|
@ -0,0 +1,18 @@
|
|||
import { fetchAccessToken } from '@/service/share'
|
||||
|
||||
export const checkOrSetAccessToken = async () => {
|
||||
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
|
||||
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
|
||||
let accessTokenJson = { [sharedToken]: '' }
|
||||
try {
|
||||
accessTokenJson = JSON.parse(accessToken)
|
||||
}
|
||||
catch (e) {
|
||||
|
||||
}
|
||||
if (!accessTokenJson[sharedToken]) {
|
||||
const res = await fetchAccessToken(sharedToken)
|
||||
accessTokenJson[sharedToken] = res.access_token
|
||||
localStorage.setItem('token', JSON.stringify(accessTokenJson))
|
||||
}
|
||||
}
|
|
@ -142,7 +142,15 @@ const baseFetch = (
|
|||
const options = Object.assign({}, baseOptions, fetchOptions)
|
||||
if (isPublicAPI) {
|
||||
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
|
||||
options.headers.set('Authorization', `bearer ${sharedToken}`)
|
||||
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
|
||||
let accessTokenJson = { [sharedToken]: '' }
|
||||
try {
|
||||
accessTokenJson = JSON.parse(accessToken)
|
||||
}
|
||||
catch (e) {
|
||||
|
||||
}
|
||||
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
|
||||
}
|
||||
|
||||
if (deleteContentType) {
|
||||
|
@ -194,7 +202,7 @@ const baseFetch = (
|
|||
case 401: {
|
||||
if (isPublicAPI) {
|
||||
Toast.notify({ type: 'error', message: 'Invalid token' })
|
||||
return
|
||||
return bodyJson.then((data: any) => Promise.reject(data))
|
||||
}
|
||||
const loginUrl = `${globalThis.location.origin}/signin`
|
||||
if (IS_CE_EDITION) {
|
||||
|
|
|
@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
|
|||
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
|
||||
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
|
||||
}
|
||||
|
||||
export const fetchAccessToken = async (appCode: string) => {
|
||||
const headers = new Headers()
|
||||
headers.append('X-App-Code', appCode)
|
||||
return get('/passport', { headers }) as Promise<{ access_token: string }>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user