From 08785219130d7752c78042f8ca9ca7540589a13c Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 27 Apr 2024 16:48:09 +0800 Subject: [PATCH] feat: use flask-restx instead of flask-restful in service API which allow OpenAPI schema file generate --- api/controllers/service_api/__init__.py | 11 +- api/controllers/service_api/app/__init__.py | 3 + api/controllers/service_api/app/app.py | 6 +- api/controllers/service_api/app/audio.py | 4 +- api/controllers/service_api/app/completion.py | 4 +- .../service_api/app/conversation.py | 6 +- api/controllers/service_api/app/file.py | 4 +- api/controllers/service_api/app/message.py | 6 +- api/controllers/service_api/app/workflow.py | 4 +- .../service_api/dataset/__init__.py | 3 + .../service_api/dataset/dataset.py | 4 +- .../service_api/dataset/document.py | 4 +- .../service_api/dataset/segment.py | 4 +- api/controllers/service_api/index.py | 16 --- api/controllers/service_api/wraps.py | 2 +- api/fields/conversation_fields.py | 2 +- api/fields/file_fields.py | 2 +- api/libs/flask_restx_external_api.py | 115 ++++++++++++++++++ api/libs/helper.py | 2 +- api/requirements.txt | 1 + 20 files changed, 156 insertions(+), 47 deletions(-) delete mode 100644 api/controllers/service_api/index.py create mode 100644 api/libs/flask_restx_external_api.py diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 082660a891..854e1eccac 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,11 +1,14 @@ from flask import Blueprint -from libs.external_api import ExternalApi +from libs.flask_restx_external_api import FlaskRestxExternalApi + +from .app import api as app_ns +from .dataset import api as dataset_ns bp = Blueprint('service_api', __name__, url_prefix='/v1') -api = ExternalApi(bp) +api = FlaskRestxExternalApi(bp, doc='/docs/', title='Dify OpenAPI', version='1.0', description='Dify OpenAPI') +api.add_namespace(app_ns) +api.add_namespace(dataset_ns) - -from . import index from .app import app, audio, completion, conversation, file, message, workflow from .dataset import dataset, document, segment diff --git a/api/controllers/service_api/app/__init__.py b/api/controllers/service_api/app/__init__.py index e69de29bb2..39ed0a2ff3 100644 --- a/api/controllers/service_api/app/__init__.py +++ b/api/controllers/service_api/app/__init__.py @@ -0,0 +1,3 @@ +from flask_restx import Namespace + +api = Namespace('app', path='/', description='App Service API') diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index bccce9b55b..0c6cea64eb 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,8 +1,8 @@ from flask import current_app -from flask_restful import Resource, fields, marshal_with +from flask_restx import Resource, fields, marshal_with -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from models.model import App, AppMode @@ -96,7 +96,7 @@ class AppInfoApi(Resource): return { 'name':app_model.name, 'description':app_model.description - } + } api.add_resource(AppParameterApi, '/parameters') diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 15c0a153b8..b4f3b5fb3e 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import ( AppUnavailableError, AudioTooLargeError, diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c1fdf249bb..7876653bda 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,10 +1,10 @@ import logging -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 02158f8b56..693d9957a1 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,9 +1,9 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound import services -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1b..522f916b32 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,8 +1,8 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restx import Resource, marshal_with import services -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import ( FileTooLargeError, NoFileUploadedError, diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 08f382b0a7..535d0b1b27 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,11 +1,11 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 2830530db5..b18ff598ec 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,9 +1,9 @@ import logging -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError -from controllers.service_api import api +from controllers.service_api.app import api from controllers.service_api.app.error import ( CompletionRequestError, NotWorkflowAppError, diff --git a/api/controllers/service_api/dataset/__init__.py b/api/controllers/service_api/dataset/__init__.py index e69de29bb2..f2e93cb63d 100644 --- a/api/controllers/service_api/dataset/__init__.py +++ b/api/controllers/service_api/dataset/__init__.py @@ -0,0 +1,3 @@ +from flask_restx import Namespace + +api = Namespace('dataset', path='/', description='Dataset Service API') \ No newline at end of file diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index bf08291d7b..8490854eff 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,8 +1,8 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse import services.dataset_service -from controllers.service_api import api +from controllers.service_api.dataset import api from controllers.service_api.dataset.error import DatasetNameDuplicateError from controllers.service_api.wraps import DatasetApiResource from core.model_runtime.entities.model_entities import ModelType diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a6..0ed7e12adf 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,13 +1,13 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from sqlalchemy import desc from werkzeug.exceptions import NotFound import services.dataset_service -from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.dataset import api from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0849eb72ba..d569468518 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,9 +1,9 @@ from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.dataset import api from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_knowledge_limit_check, diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py deleted file mode 100644 index 932388b562..0000000000 --- a/api/controllers/service_api/index.py +++ /dev/null @@ -1,16 +0,0 @@ -from flask import current_app -from flask_restful import Resource - -from controllers.service_api import api - - -class IndexApi(Resource): - def get(self): - return { - "welcome": "Dify OpenAPI", - "api_version": "v1", - "server_version": current_app.config['CURRENT_VERSION'] - } - - -api.add_resource(IndexApi, '/') diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 8ae81531ae..2dea836732 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -6,7 +6,7 @@ from typing import Optional from flask import current_app, request from flask_login import user_logged_in -from flask_restful import Resource +from flask_restx import Resource from pydantic import BaseModel from werkzeug.exceptions import Forbidden, NotFound, Unauthorized diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 79ceb02685..ee5c503ed9 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc..6d7b6edaa7 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/libs/flask_restx_external_api.py b/api/libs/flask_restx_external_api.py new file mode 100644 index 0000000000..8070176302 --- /dev/null +++ b/api/libs/flask_restx_external_api.py @@ -0,0 +1,115 @@ +import re +import sys + +from flask import current_app, got_request_exception +from flask_restx import Api +from werkzeug.datastructures import Headers +from werkzeug.exceptions import HTTPException +from werkzeug.http import HTTP_STATUS_CODES + + +class FlaskRestxExternalApi(Api): + + def handle_error(self, e): + """Error handler for the API transforms a raised exception into a Flask + response, with the appropriate HTTP status code and body. + + :param e: the raised Exception object + :type e: Exception + + """ + got_request_exception.send(current_app, exception=e) + + headers = Headers() + if isinstance(e, HTTPException): + if e.response is not None: + resp = e.get_response() + return resp + + status_code = e.code + default_data = { + 'code': re.sub(r'(?= 500: + exc_info = sys.exc_info() + if exc_info[1] is None: + exc_info = None + current_app.log_exception(exc_info) + + if status_code == 406 and self.default_mediatype is None: + # if we are handling NotAcceptable (406), make sure that + # make_response uses a representation we support as the + # default mediatype (so that make_response doesn't throw + # another NotAcceptable error). + supported_mediatypes = list(self.representations.keys()) # only supported application/json + fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" + data = { + 'code': 'not_acceptable', + 'message': data.get('message') + } + resp = self.make_response( + data, + status_code, + headers, + fallback_mediatype = fallback_mediatype + ) + elif status_code == 400: + if isinstance(data.get('message'), dict): + param_key, param_value = list(data.get('message').items())[0] + data = { + 'code': 'invalid_param', + 'message': param_value, + 'params': param_key + } + else: + if 'code' not in data: + data['code'] = 'unknown' + + resp = self.make_response(data, status_code, headers) + else: + if 'code' not in data: + data['code'] = 'unknown' + + resp = self.make_response(data, status_code, headers) + + if status_code == 401: + resp = self.unauthorized(resp) + return resp + + def render_root(self): + return { + "welcome": "Dify OpenAPI", + "api_version": "v1", + "server_version": current_app.config['CURRENT_VERSION'] + } diff --git a/api/libs/helper.py b/api/libs/helper.py index f9cf590b7a..45db5e099e 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,7 +11,7 @@ from typing import Union from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restx import fields def run(script): diff --git a/api/requirements.txt b/api/requirements.txt index 5984405a41..1bef96d6af 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -6,6 +6,7 @@ Flask-Compress~=1.14 flask-login~=0.6.3 flask-migrate~=4.0.5 flask-restful~=0.3.10 +flask-restx~=1.3.0 flask-cors~=4.0.0 gunicorn~=22.0.0 gevent~=23.9.1