From d49ac1e4ac4e8a715d0eb32be002ac019829b15b Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Tue, 11 Jul 2023 15:21:20 +0800 Subject: [PATCH] Feature/use jwt in web (#533) Co-authored-by: crazywoola Co-authored-by: StyleZhang --- .gitignore | 1 + api/app.py | 2 +- api/controllers/web/__init__.py | 2 +- api/controllers/web/passport.py | 64 +++++++++++++ api/controllers/web/wraps.py | 94 ++++--------------- api/libs/passport.py | 20 ++++ api/requirements.txt | 3 +- web/app/components/share/chat/index.tsx | 19 +++- .../share/text-generation/index.tsx | 10 +- web/app/components/share/utils.ts | 18 ++++ web/service/base.ts | 12 ++- web/service/share.ts | 6 ++ 12 files changed, 161 insertions(+), 90 deletions(-) create mode 100644 api/controllers/web/passport.py create mode 100644 api/libs/passport.py create mode 100644 web/app/components/share/utils.ts diff --git a/.gitignore b/.gitignore index 957c39cdd1..3a7773ccbe 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.conda/ # Spyder project settings .spyderproject diff --git a/api/app.py b/api/app.py index 3beb2cf706..1b67c0cbd2 100644 --- a/api/app.py +++ b/api/app.py @@ -155,7 +155,7 @@ def register_blueprints(app): resources={ r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, supports_credentials=True, - allow_headers=['Content-Type', 'Authorization'], + allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], expose_headers=['X-Version', 'X-Env'] ) diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index c419453665..0808dce5c4 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api') api = ExternalApi(bp) -from . import completion, app, conversation, message, site, saved_message, audio +from . import completion, app, conversation, message, site, saved_message, audio, passport diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py new file mode 100644 index 0000000000..219f6e731f --- /dev/null +++ b/api/controllers/web/passport.py @@ -0,0 +1,64 @@ +# -*- coding:utf-8 -*- +import uuid +from controllers.web import api +from flask_restful import Resource +from flask import request +from werkzeug.exceptions import Unauthorized, NotFound +from models.model import Site, EndUser, App +from extensions.ext_database import db +from libs.passport import PassportService + +class PassportResource(Resource): + """Base resource for passport.""" + def get(self): + app_id = request.headers.get('X-App-Code') + if app_id is None: + raise Unauthorized('X-App-Code header is missing.') + + # get site from db and check if it is normal + site = db.session.query(Site).filter( + Site.code == app_id, + Site.status == 'normal' + ).first() + if not site: + raise NotFound() + # get app from db and check if it is normal and enable_site + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model or app_model.status != 'normal' or not app_model.enable_site: + raise NotFound() + + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type='browser', + is_anonymous=True, + session_id=generate_session_id(), + ) + db.session.add(end_user) + db.session.commit() + + payload = { + "iss": site.app_id, + 'sub': 'Web API Passport', + 'app_id': site.app_id, + 'end_user_id': end_user.id, + } + + tk = PassportService().issue(payload) + + return { + 'access_token': tk, + } + +api.add_resource(PassportResource, '/passport') + +def generate_session_id(): + """ + Generate a unique session ID. + """ + while True: + session_id = str(uuid.uuid4()) + existing_count = db.session.query(EndUser) \ + .filter(EndUser.session_id == session_id).count() + if existing_count == 0: + return session_id diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 9321c427c2..4d48a190f2 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,110 +1,48 @@ # -*- coding:utf-8 -*- -import uuid from functools import wraps -from flask import request, session +from flask import request from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized from extensions.ext_database import db -from models.model import App, Site, EndUser +from models.model import App, EndUser +from libs.passport import PassportService - -def validate_token(view=None): +def validate_jwt_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - site = validate_and_get_site() - - app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model: - raise NotFound() - - if app_model.status != 'normal': - raise NotFound() - - if not app_model.enable_site: - raise NotFound() - - end_user = create_or_update_end_user_for_session(app_model) + app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) return decorated - if view: return decorator(view) return decorator - -def validate_and_get_site(): - """ - Validate and get API token. - """ +def decode_jwt_token(): auth_header = request.headers.get('Authorization') if auth_header is None: raise Unauthorized('Authorization header is missing.') if ' ' not in auth_header: raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - - auth_scheme, auth_token = auth_header.split(None, 1) + + auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != 'bearer': raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - - site = db.session.query(Site).filter( - Site.code == auth_token, - Site.status == 'normal' - ).first() - - if not site: + decoded = PassportService().verify(tk) + app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + if not app_model: + raise NotFound() + end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + if not end_user: raise NotFound() - return site - - -def create_or_update_end_user_for_session(app_model): - """ - Create or update session terminal based on session ID. - """ - if 'session_id' not in session: - session['session_id'] = generate_session_id() - - session_id = session.get('session_id') - end_user = db.session.query(EndUser) \ - .filter( - EndUser.session_id == session_id, - EndUser.type == 'browser' - ).first() - - if end_user is None: - end_user = EndUser( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - type='browser', - is_anonymous=True, - session_id=session_id - ) - db.session.add(end_user) - db.session.commit() - - return end_user - - -def generate_session_id(): - """ - Generate a unique session ID. - """ - count = 1 - session_id = '' - while count != 0: - session_id = str(uuid.uuid4()) - count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() - - return session_id - + return app_model, end_user class WebApiResource(Resource): - method_decorators = [validate_token] + method_decorators = [validate_jwt_token] diff --git a/api/libs/passport.py b/api/libs/passport.py new file mode 100644 index 0000000000..c3bd9e566f --- /dev/null +++ b/api/libs/passport.py @@ -0,0 +1,20 @@ +# -*- coding:utf-8 -*- +import jwt +from werkzeug.exceptions import Unauthorized +from flask import current_app +class PassportService: + def __init__(self): + self.sk = current_app.config.get('SECRET_KEY') + + def issue(self, payload): + return jwt.encode(payload, self.sk, algorithm='HS256') + + def verify(self, token): + try: + return jwt.decode(token, self.sk, algorithms=['HS256']) + except jwt.exceptions.InvalidSignatureError: + raise Unauthorized('Invalid token signature.') + except jwt.exceptions.DecodeError: + raise Unauthorized('Invalid token.') + except jwt.exceptions.ExpiredSignatureError: + raise Unauthorized('Token has expired.') diff --git a/api/requirements.txt b/api/requirements.txt index e129eb4b37..fe554653b0 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -32,4 +32,5 @@ redis~=4.5.4 openpyxl==3.1.2 chardet~=5.1.0 docx2txt==0.8 -pypdfium2==4.16.0 \ No newline at end of file +pypdfium2==4.16.0 +pyjwt~=2.6.0 \ No newline at end of file diff --git a/web/app/components/share/chat/index.tsx b/web/app/components/share/chat/index.tsx index 6b81454772..454102ba8b 100644 --- a/web/app/components/share/chat/index.tsx +++ b/web/app/components/share/chat/index.tsx @@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector' import produce from 'immer' import { useBoolean, useGetState } from 'ahooks' import AppUnavailable from '../../base/app-unavailable' +import { checkOrSetAccessToken } from '../utils' import useConversation from './hooks/use-conversation' import s from './style.module.css' import { ToastContext } from '@/app/components/base/toast' import Sidebar from '@/app/components/share/chat/sidebar' import ConfigSence from '@/app/components/share/chat/config-scence' import Header from '@/app/components/share/header' -import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share' +import { + delConversation, + fetchAppInfo, + fetchAppParams, + fetchChatList, + fetchConversations, + fetchSuggestedQuestions, + pinConversation, + sendChatMessage, + stopChatMessageResponding, + unpinConversation, + updateFeedback, +} from '@/service/share' import type { ConversationItem, SiteInfo } from '@/models/share' import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug' import type { Feedbacktype, IChatItem } from '@/app/components/app/chat' @@ -296,7 +309,9 @@ const Main: FC = ({ return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100) } - const fetchInitData = () => { + const fetchInitData = async () => { + await checkOrSetAccessToken() + return Promise.all([isInstalledApp ? { app_id: installedAppInfo?.id, diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index 5b2a077dcc..01bbcea5e1 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks' import { XMarkIcon } from '@heroicons/react/24/outline' import TabHeader from '../../base/tab-header' import Button from '../../base/button' +import { checkOrSetAccessToken } from '../utils' import s from './style.module.css' import RunBatch from './run-batch' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' @@ -76,9 +77,6 @@ const TextGeneration: FC = ({ const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id) setSavedMessages(res.data) } - useEffect(() => { - fetchSavedMessage() - }, []) const handleSaveMessage = async (messageId: string) => { await saveMessage(messageId, isInstalledApp, installedAppInfo?.id) notify({ type: 'success', message: t('common.api.saved') }) @@ -256,7 +254,9 @@ const TextGeneration: FC = ({ setAllTaskList(newAllTaskList) } - const fetchInitData = () => { + const fetchInitData = async () => { + await checkOrSetAccessToken() + return Promise.all([isInstalledApp ? { app_id: installedAppInfo?.id, @@ -267,7 +267,7 @@ const TextGeneration: FC = ({ }, plan: 'basic', } - : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)]) + : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()]) } useEffect(() => { diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts new file mode 100644 index 0000000000..7c10644e2a --- /dev/null +++ b/web/app/components/share/utils.ts @@ -0,0 +1,18 @@ +import { fetchAccessToken } from '@/service/share' + +export const checkOrSetAccessToken = async () => { + const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] + const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) + let accessTokenJson = { [sharedToken]: '' } + try { + accessTokenJson = JSON.parse(accessToken) + } + catch (e) { + + } + if (!accessTokenJson[sharedToken]) { + const res = await fetchAccessToken(sharedToken) + accessTokenJson[sharedToken] = res.access_token + localStorage.setItem('token', JSON.stringify(accessTokenJson)) + } +} diff --git a/web/service/base.ts b/web/service/base.ts index dfdfae5331..019c8d5d0e 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -142,7 +142,15 @@ const baseFetch = ( const options = Object.assign({}, baseOptions, fetchOptions) if (isPublicAPI) { const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - options.headers.set('Authorization', `bearer ${sharedToken}`) + const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) + let accessTokenJson = { [sharedToken]: '' } + try { + accessTokenJson = JSON.parse(accessToken) + } + catch (e) { + + } + options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`) } if (deleteContentType) { @@ -194,7 +202,7 @@ const baseFetch = ( case 401: { if (isPublicAPI) { Toast.notify({ type: 'error', message: 'Invalid token' }) - return + return bodyJson.then((data: any) => Promise.reject(data)) } const loginUrl = `${globalThis.location.origin}/signin` if (IS_CE_EDITION) { diff --git a/web/service/share.ts b/web/service/share.ts index df01abc3eb..abdba8cc05 100644 --- a/web/service/share.ts +++ b/web/service/share.ts @@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => { return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }> } + +export const fetchAccessToken = async (appCode: string) => { + const headers = new Headers() + headers.append('X-App-Code', appCode) + return get('/passport', { headers }) as Promise<{ access_token: string }> +}