Feature/use jwt in web (#533)

Co-authored-by: crazywoola <li.zheng@dentsplysirona.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
crazywoola 2023-07-11 15:21:20 +08:00 committed by GitHub
parent 57de19a5ca
commit d49ac1e4ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 161 additions and 90 deletions

1
.gitignore vendored
View File

@ -109,6 +109,7 @@ venv/
ENV/
env.bak/
venv.bak/
.conda/
# Spyder project settings
.spyderproject

View File

@ -155,7 +155,7 @@ def register_blueprints(app):
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)

View File

@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
from . import completion, app, conversation, message, site, saved_message, audio
from . import completion, app, conversation, message, site, saved_message, audio, passport

View File

@ -0,0 +1,64 @@
# -*- coding:utf-8 -*-
import uuid
from controllers.web import api
from flask_restful import Resource
from flask import request
from werkzeug.exceptions import Unauthorized, NotFound
from models.model import Site, EndUser, App
from extensions.ext_database import db
from libs.passport import PassportService
class PassportResource(Resource):
"""Base resource for passport."""
def get(self):
app_id = request.headers.get('X-App-Code')
if app_id is None:
raise Unauthorized('X-App-Code header is missing.')
# get site from db and check if it is normal
site = db.session.query(Site).filter(
Site.code == app_id,
Site.status == 'normal'
).first()
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
raise NotFound()
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,
'sub': 'Web API Passport',
'app_id': site.app_id,
'end_user_id': end_user.id,
}
tk = PassportService().issue(payload)
return {
'access_token': tk,
}
api.add_resource(PassportResource, '/passport')
def generate_session_id():
"""
Generate a unique session ID.
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
if existing_count == 0:
return session_id

View File

@ -1,45 +1,27 @@
# -*- coding:utf-8 -*-
import uuid
from functools import wraps
from flask import request, session
from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from models.model import App, Site, EndUser
from models.model import App, EndUser
from libs.passport import PassportService
def validate_token(view=None):
def validate_jwt_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
site = validate_and_get_site()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model:
raise NotFound()
if app_model.status != 'normal':
raise NotFound()
if not app_model.enable_site:
raise NotFound()
end_user = create_or_update_end_user_for_session(app_model)
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def validate_and_get_site():
"""
Validate and get API token.
"""
def decode_jwt_token():
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized('Authorization header is missing.')
@ -47,64 +29,20 @@ def validate_and_get_site():
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
site = db.session.query(Site).filter(
Site.code == auth_token,
Site.status == 'normal'
).first()
if not site:
decoded = PassportService().verify(tk)
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
if not app_model:
raise NotFound()
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
if not end_user:
raise NotFound()
return site
def create_or_update_end_user_for_session(app_model):
"""
Create or update session terminal based on session ID.
"""
if 'session_id' not in session:
session['session_id'] = generate_session_id()
session_id = session.get('session_id')
end_user = db.session.query(EndUser) \
.filter(
EndUser.session_id == session_id,
EndUser.type == 'browser'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=session_id
)
db.session.add(end_user)
db.session.commit()
return end_user
def generate_session_id():
"""
Generate a unique session ID.
"""
count = 1
session_id = ''
while count != 0:
session_id = str(uuid.uuid4())
count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
return session_id
return app_model, end_user
class WebApiResource(Resource):
method_decorators = [validate_token]
method_decorators = [validate_jwt_token]

20
api/libs/passport.py Normal file
View File

@ -0,0 +1,20 @@
# -*- coding:utf-8 -*-
import jwt
from werkzeug.exceptions import Unauthorized
from flask import current_app
class PassportService:
def __init__(self):
self.sk = current_app.config.get('SECRET_KEY')
def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm='HS256')
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=['HS256'])
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized('Invalid token signature.')
except jwt.exceptions.DecodeError:
raise Unauthorized('Invalid token.')
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized('Token has expired.')

View File

@ -33,3 +33,4 @@ openpyxl==3.1.2
chardet~=5.1.0
docx2txt==0.8
pypdfium2==4.16.0
pyjwt~=2.6.0

View File

@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
import produce from 'immer'
import { useBoolean, useGetState } from 'ahooks'
import AppUnavailable from '../../base/app-unavailable'
import { checkOrSetAccessToken } from '../utils'
import useConversation from './hooks/use-conversation'
import s from './style.module.css'
import { ToastContext } from '@/app/components/base/toast'
import Sidebar from '@/app/components/share/chat/sidebar'
import ConfigSence from '@/app/components/share/chat/config-scence'
import Header from '@/app/components/share/header'
import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
import {
delConversation,
fetchAppInfo,
fetchAppParams,
fetchChatList,
fetchConversations,
fetchSuggestedQuestions,
pinConversation,
sendChatMessage,
stopChatMessageResponding,
unpinConversation,
updateFeedback,
} from '@/service/share'
import type { ConversationItem, SiteInfo } from '@/models/share'
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
@ -296,7 +309,9 @@ const Main: FC<IMainProps> = ({
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
}
const fetchInitData = () => {
const fetchInitData = async () => {
await checkOrSetAccessToken()
return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,

View File

@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
import { XMarkIcon } from '@heroicons/react/24/outline'
import TabHeader from '../../base/tab-header'
import Button from '../../base/button'
import { checkOrSetAccessToken } from '../utils'
import s from './style.module.css'
import RunBatch from './run-batch'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
@ -76,9 +77,6 @@ const TextGeneration: FC<IMainProps> = ({
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
setSavedMessages(res.data)
}
useEffect(() => {
fetchSavedMessage()
}, [])
const handleSaveMessage = async (messageId: string) => {
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
notify({ type: 'success', message: t('common.api.saved') })
@ -256,7 +254,9 @@ const TextGeneration: FC<IMainProps> = ({
setAllTaskList(newAllTaskList)
}
const fetchInitData = () => {
const fetchInitData = async () => {
await checkOrSetAccessToken()
return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,
@ -267,7 +267,7 @@ const TextGeneration: FC<IMainProps> = ({
},
plan: 'basic',
}
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
}
useEffect(() => {

View File

@ -0,0 +1,18 @@
import { fetchAccessToken } from '@/service/share'
export const checkOrSetAccessToken = async () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch (e) {
}
if (!accessTokenJson[sharedToken]) {
const res = await fetchAccessToken(sharedToken)
accessTokenJson[sharedToken] = res.access_token
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
}

View File

@ -142,7 +142,15 @@ const baseFetch = (
const options = Object.assign({}, baseOptions, fetchOptions)
if (isPublicAPI) {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
options.headers.set('Authorization', `bearer ${sharedToken}`)
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch (e) {
}
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
}
if (deleteContentType) {
@ -194,7 +202,7 @@ const baseFetch = (
case 401: {
if (isPublicAPI) {
Toast.notify({ type: 'error', message: 'Invalid token' })
return
return bodyJson.then((data: any) => Promise.reject(data))
}
const loginUrl = `${globalThis.location.origin}/signin`
if (IS_CE_EDITION) {

View File

@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
}
export const fetchAccessToken = async (appCode: string) => {
const headers = new Headers()
headers.append('X-App-Code', appCode)
return get('/passport', { headers }) as Promise<{ access_token: string }>
}