From d49ac1e4ac4e8a715d0eb32be002ac019829b15b Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Tue, 11 Jul 2023 15:21:20 +0800
Subject: [PATCH] Feature/use jwt in web (#533)
Co-authored-by: crazywoola
Co-authored-by: StyleZhang
---
.gitignore | 1 +
api/app.py | 2 +-
api/controllers/web/__init__.py | 2 +-
api/controllers/web/passport.py | 64 +++++++++++++
api/controllers/web/wraps.py | 94 ++++---------------
api/libs/passport.py | 20 ++++
api/requirements.txt | 3 +-
web/app/components/share/chat/index.tsx | 19 +++-
.../share/text-generation/index.tsx | 10 +-
web/app/components/share/utils.ts | 18 ++++
web/service/base.ts | 12 ++-
web/service/share.ts | 6 ++
12 files changed, 161 insertions(+), 90 deletions(-)
create mode 100644 api/controllers/web/passport.py
create mode 100644 api/libs/passport.py
create mode 100644 web/app/components/share/utils.ts
diff --git a/.gitignore b/.gitignore
index 957c39cdd1..3a7773ccbe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -109,6 +109,7 @@ venv/
ENV/
env.bak/
venv.bak/
+.conda/
# Spyder project settings
.spyderproject
diff --git a/api/app.py b/api/app.py
index 3beb2cf706..1b67c0cbd2 100644
--- a/api/app.py
+++ b/api/app.py
@@ -155,7 +155,7 @@ def register_blueprints(app):
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
- allow_headers=['Content-Type', 'Authorization'],
+ allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py
index c419453665..0808dce5c4 100644
--- a/api/controllers/web/__init__.py
+++ b/api/controllers/web/__init__.py
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
-from . import completion, app, conversation, message, site, saved_message, audio
+from . import completion, app, conversation, message, site, saved_message, audio, passport
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
new file mode 100644
index 0000000000..219f6e731f
--- /dev/null
+++ b/api/controllers/web/passport.py
@@ -0,0 +1,64 @@
+# -*- coding:utf-8 -*-
+import uuid
+from controllers.web import api
+from flask_restful import Resource
+from flask import request
+from werkzeug.exceptions import Unauthorized, NotFound
+from models.model import Site, EndUser, App
+from extensions.ext_database import db
+from libs.passport import PassportService
+
+class PassportResource(Resource):
+ """Base resource for passport."""
+ def get(self):
+ app_id = request.headers.get('X-App-Code')
+ if app_id is None:
+ raise Unauthorized('X-App-Code header is missing.')
+
+ # get site from db and check if it is normal
+ site = db.session.query(Site).filter(
+ Site.code == app_id,
+ Site.status == 'normal'
+ ).first()
+ if not site:
+ raise NotFound()
+ # get app from db and check if it is normal and enable_site
+ app_model = db.session.query(App).filter(App.id == site.app_id).first()
+ if not app_model or app_model.status != 'normal' or not app_model.enable_site:
+ raise NotFound()
+
+ end_user = EndUser(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ type='browser',
+ is_anonymous=True,
+ session_id=generate_session_id(),
+ )
+ db.session.add(end_user)
+ db.session.commit()
+
+ payload = {
+ "iss": site.app_id,
+ 'sub': 'Web API Passport',
+ 'app_id': site.app_id,
+ 'end_user_id': end_user.id,
+ }
+
+ tk = PassportService().issue(payload)
+
+ return {
+ 'access_token': tk,
+ }
+
+api.add_resource(PassportResource, '/passport')
+
+def generate_session_id():
+ """
+ Generate a unique session ID.
+ """
+ while True:
+ session_id = str(uuid.uuid4())
+ existing_count = db.session.query(EndUser) \
+ .filter(EndUser.session_id == session_id).count()
+ if existing_count == 0:
+ return session_id
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index 9321c427c2..4d48a190f2 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -1,110 +1,48 @@
# -*- coding:utf-8 -*-
-import uuid
from functools import wraps
-from flask import request, session
+from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
-from models.model import App, Site, EndUser
+from models.model import App, EndUser
+from libs.passport import PassportService
-
-def validate_token(view=None):
+def validate_jwt_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
- site = validate_and_get_site()
-
- app_model = db.session.query(App).filter(App.id == site.app_id).first()
- if not app_model:
- raise NotFound()
-
- if app_model.status != 'normal':
- raise NotFound()
-
- if not app_model.enable_site:
- raise NotFound()
-
- end_user = create_or_update_end_user_for_session(app_model)
+ app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
return decorated
-
if view:
return decorator(view)
return decorator
-
-def validate_and_get_site():
- """
- Validate and get API token.
- """
+def decode_jwt_token():
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized('Authorization header is missing.')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
-
- auth_scheme, auth_token = auth_header.split(None, 1)
+
+ auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.')
-
- site = db.session.query(Site).filter(
- Site.code == auth_token,
- Site.status == 'normal'
- ).first()
-
- if not site:
+ decoded = PassportService().verify(tk)
+ app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
+ if not app_model:
+ raise NotFound()
+ end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
+ if not end_user:
raise NotFound()
- return site
-
-
-def create_or_update_end_user_for_session(app_model):
- """
- Create or update session terminal based on session ID.
- """
- if 'session_id' not in session:
- session['session_id'] = generate_session_id()
-
- session_id = session.get('session_id')
- end_user = db.session.query(EndUser) \
- .filter(
- EndUser.session_id == session_id,
- EndUser.type == 'browser'
- ).first()
-
- if end_user is None:
- end_user = EndUser(
- tenant_id=app_model.tenant_id,
- app_id=app_model.id,
- type='browser',
- is_anonymous=True,
- session_id=session_id
- )
- db.session.add(end_user)
- db.session.commit()
-
- return end_user
-
-
-def generate_session_id():
- """
- Generate a unique session ID.
- """
- count = 1
- session_id = ''
- while count != 0:
- session_id = str(uuid.uuid4())
- count = db.session.query(EndUser) \
- .filter(EndUser.session_id == session_id).count()
-
- return session_id
-
+ return app_model, end_user
class WebApiResource(Resource):
- method_decorators = [validate_token]
+ method_decorators = [validate_jwt_token]
diff --git a/api/libs/passport.py b/api/libs/passport.py
new file mode 100644
index 0000000000..c3bd9e566f
--- /dev/null
+++ b/api/libs/passport.py
@@ -0,0 +1,20 @@
+# -*- coding:utf-8 -*-
+import jwt
+from werkzeug.exceptions import Unauthorized
+from flask import current_app
+class PassportService:
+ def __init__(self):
+ self.sk = current_app.config.get('SECRET_KEY')
+
+ def issue(self, payload):
+ return jwt.encode(payload, self.sk, algorithm='HS256')
+
+ def verify(self, token):
+ try:
+ return jwt.decode(token, self.sk, algorithms=['HS256'])
+ except jwt.exceptions.InvalidSignatureError:
+ raise Unauthorized('Invalid token signature.')
+ except jwt.exceptions.DecodeError:
+ raise Unauthorized('Invalid token.')
+ except jwt.exceptions.ExpiredSignatureError:
+ raise Unauthorized('Token has expired.')
diff --git a/api/requirements.txt b/api/requirements.txt
index e129eb4b37..fe554653b0 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -32,4 +32,5 @@ redis~=4.5.4
openpyxl==3.1.2
chardet~=5.1.0
docx2txt==0.8
-pypdfium2==4.16.0
\ No newline at end of file
+pypdfium2==4.16.0
+pyjwt~=2.6.0
\ No newline at end of file
diff --git a/web/app/components/share/chat/index.tsx b/web/app/components/share/chat/index.tsx
index 6b81454772..454102ba8b 100644
--- a/web/app/components/share/chat/index.tsx
+++ b/web/app/components/share/chat/index.tsx
@@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
import produce from 'immer'
import { useBoolean, useGetState } from 'ahooks'
import AppUnavailable from '../../base/app-unavailable'
+import { checkOrSetAccessToken } from '../utils'
import useConversation from './hooks/use-conversation'
import s from './style.module.css'
import { ToastContext } from '@/app/components/base/toast'
import Sidebar from '@/app/components/share/chat/sidebar'
import ConfigSence from '@/app/components/share/chat/config-scence'
import Header from '@/app/components/share/header'
-import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
+import {
+ delConversation,
+ fetchAppInfo,
+ fetchAppParams,
+ fetchChatList,
+ fetchConversations,
+ fetchSuggestedQuestions,
+ pinConversation,
+ sendChatMessage,
+ stopChatMessageResponding,
+ unpinConversation,
+ updateFeedback,
+} from '@/service/share'
import type { ConversationItem, SiteInfo } from '@/models/share'
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
@@ -296,7 +309,9 @@ const Main: FC = ({
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
}
- const fetchInitData = () => {
+ const fetchInitData = async () => {
+ await checkOrSetAccessToken()
+
return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,
diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx
index 5b2a077dcc..01bbcea5e1 100644
--- a/web/app/components/share/text-generation/index.tsx
+++ b/web/app/components/share/text-generation/index.tsx
@@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
import { XMarkIcon } from '@heroicons/react/24/outline'
import TabHeader from '../../base/tab-header'
import Button from '../../base/button'
+import { checkOrSetAccessToken } from '../utils'
import s from './style.module.css'
import RunBatch from './run-batch'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
@@ -76,9 +77,6 @@ const TextGeneration: FC = ({
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
setSavedMessages(res.data)
}
- useEffect(() => {
- fetchSavedMessage()
- }, [])
const handleSaveMessage = async (messageId: string) => {
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
notify({ type: 'success', message: t('common.api.saved') })
@@ -256,7 +254,9 @@ const TextGeneration: FC = ({
setAllTaskList(newAllTaskList)
}
- const fetchInitData = () => {
+ const fetchInitData = async () => {
+ await checkOrSetAccessToken()
+
return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,
@@ -267,7 +267,7 @@ const TextGeneration: FC = ({
},
plan: 'basic',
}
- : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
+ : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
}
useEffect(() => {
diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts
new file mode 100644
index 0000000000..7c10644e2a
--- /dev/null
+++ b/web/app/components/share/utils.ts
@@ -0,0 +1,18 @@
+import { fetchAccessToken } from '@/service/share'
+
+export const checkOrSetAccessToken = async () => {
+ const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
+ const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+ let accessTokenJson = { [sharedToken]: '' }
+ try {
+ accessTokenJson = JSON.parse(accessToken)
+ }
+ catch (e) {
+
+ }
+ if (!accessTokenJson[sharedToken]) {
+ const res = await fetchAccessToken(sharedToken)
+ accessTokenJson[sharedToken] = res.access_token
+ localStorage.setItem('token', JSON.stringify(accessTokenJson))
+ }
+}
diff --git a/web/service/base.ts b/web/service/base.ts
index dfdfae5331..019c8d5d0e 100644
--- a/web/service/base.ts
+++ b/web/service/base.ts
@@ -142,7 +142,15 @@ const baseFetch = (
const options = Object.assign({}, baseOptions, fetchOptions)
if (isPublicAPI) {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
- options.headers.set('Authorization', `bearer ${sharedToken}`)
+ const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+ let accessTokenJson = { [sharedToken]: '' }
+ try {
+ accessTokenJson = JSON.parse(accessToken)
+ }
+ catch (e) {
+
+ }
+ options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
}
if (deleteContentType) {
@@ -194,7 +202,7 @@ const baseFetch = (
case 401: {
if (isPublicAPI) {
Toast.notify({ type: 'error', message: 'Invalid token' })
- return
+ return bodyJson.then((data: any) => Promise.reject(data))
}
const loginUrl = `${globalThis.location.origin}/signin`
if (IS_CE_EDITION) {
diff --git a/web/service/share.ts b/web/service/share.ts
index df01abc3eb..abdba8cc05 100644
--- a/web/service/share.ts
+++ b/web/service/share.ts
@@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
}
+
+export const fetchAccessToken = async (appCode: string) => {
+ const headers = new Headers()
+ headers.append('X-App-Code', appCode)
+ return get('/passport', { headers }) as Promise<{ access_token: string }>
+}