feat: use flask-restx instead of flask-restful in service API which allow OpenAPI schema file generate

This commit is contained in:
takatost 2024-04-27 16:48:09 +08:00
parent 661b30784e
commit 0878521913
20 changed files with 156 additions and 47 deletions

View File

@ -1,11 +1,14 @@
from flask import Blueprint from flask import Blueprint
from libs.external_api import ExternalApi from libs.flask_restx_external_api import FlaskRestxExternalApi
from .app import api as app_ns
from .dataset import api as dataset_ns
bp = Blueprint('service_api', __name__, url_prefix='/v1') bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp) api = FlaskRestxExternalApi(bp, doc='/docs/', title='Dify OpenAPI', version='1.0', description='Dify OpenAPI')
api.add_namespace(app_ns)
api.add_namespace(dataset_ns)
from . import index
from .app import app, audio, completion, conversation, file, message, workflow from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment from .dataset import dataset, document, segment

View File

@ -0,0 +1,3 @@
from flask_restx import Namespace
api = Namespace('app', path='/', description='App Service API')

View File

@ -1,8 +1,8 @@
from flask import current_app from flask import current_app
from flask_restful import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from models.model import App, AppMode from models.model import App, AppMode

View File

@ -1,11 +1,11 @@
import logging import logging
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,

View File

@ -1,10 +1,10 @@
import logging import logging
from flask_restful import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,

View File

@ -1,9 +1,9 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -1,8 +1,8 @@
from flask import request from flask import request
from flask_restful import Resource, marshal_with from flask_restx import Resource, marshal_with
import services import services
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
FileTooLargeError, FileTooLargeError,
NoFileUploadedError, NoFileUploadedError,

View File

@ -1,11 +1,11 @@
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -1,9 +1,9 @@
import logging import logging
from flask_restful import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.service_api import api from controllers.service_api.app import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
CompletionRequestError, CompletionRequestError,
NotWorkflowAppError, NotWorkflowAppError,

View File

@ -0,0 +1,3 @@
from flask_restx import Namespace
api = Namespace('dataset', path='/', description='Dataset Service API')

View File

@ -1,8 +1,8 @@
from flask import request from flask import request
from flask_restful import marshal, reqparse from flask_restx import marshal, reqparse
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api.dataset import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType

View File

@ -1,13 +1,13 @@
import json import json
from flask import request from flask import request
from flask_restful import marshal, reqparse from flask_restx import marshal, reqparse
from sqlalchemy import desc from sqlalchemy import desc
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset import api
from controllers.service_api.dataset.error import ( from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError, ArchivedDocumentImmutableError,
DocumentIndexingError, DocumentIndexingError,

View File

@ -1,9 +1,9 @@
from flask_login import current_user from flask_login import current_user
from flask_restful import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset import api
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
DatasetApiResource, DatasetApiResource,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,

View File

@ -1,16 +0,0 @@
from flask import current_app
from flask_restful import Resource
from controllers.service_api import api
class IndexApi(Resource):
def get(self):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
"server_version": current_app.config['CURRENT_VERSION']
}
api.add_resource(IndexApi, '/')

View File

@ -6,7 +6,7 @@ from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restful import Resource from flask_restx import Resource
from pydantic import BaseModel from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

View File

@ -1,4 +1,4 @@
from flask_restful import fields from flask_restx import fields
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
from libs.helper import TimestampField from libs.helper import TimestampField

View File

@ -1,4 +1,4 @@
from flask_restful import fields from flask_restx import fields
from libs.helper import TimestampField from libs.helper import TimestampField

View File

@ -0,0 +1,115 @@
import re
import sys
from flask import current_app, got_request_exception
from flask_restx import Api
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
from werkzeug.http import HTTP_STATUS_CODES
class FlaskRestxExternalApi(Api):
def handle_error(self, e):
"""Error handler for the API transforms a raised exception into a Flask
response, with the appropriate HTTP status code and body.
:param e: the raised Exception object
:type e: Exception
"""
got_request_exception.send(current_app, exception=e)
headers = Headers()
if isinstance(e, HTTPException):
if e.response is not None:
resp = e.get_response()
return resp
status_code = e.code
default_data = {
'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
'message': getattr(e, 'description', HTTP_STATUS_CODES.get(status_code, '')),
'status': status_code
}
if default_data['message'] and default_data['message'] == 'Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)':
default_data['message'] = 'Invalid JSON payload received or JSON payload is empty.'
headers = e.get_response().headers
elif isinstance(e, ValueError):
status_code = 400
default_data = {
'code': 'invalid_param',
'message': str(e),
'status': status_code
}
else:
status_code = 500
default_data = {
'message': HTTP_STATUS_CODES.get(status_code, ''),
}
# Werkzeug exceptions generate a content-length header which is added
# to the response in addition to the actual content-length header
# https://github.com/flask-restful/flask-restful/issues/534
remove_headers = ('Content-Length',)
for header in remove_headers:
headers.pop(header, None)
data = getattr(e, 'data', default_data)
# record the exception in the logs when we have a server error of status code: 500
if status_code and status_code >= 500:
exc_info = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
if status_code == 406 and self.default_mediatype is None:
# if we are handling NotAcceptable (406), make sure that
# make_response uses a representation we support as the
# default mediatype (so that make_response doesn't throw
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys()) # only supported application/json
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
data = {
'code': 'not_acceptable',
'message': data.get('message')
}
resp = self.make_response(
data,
status_code,
headers,
fallback_mediatype = fallback_mediatype
)
elif status_code == 400:
if isinstance(data.get('message'), dict):
param_key, param_value = list(data.get('message').items())[0]
data = {
'code': 'invalid_param',
'message': param_value,
'params': param_key
}
else:
if 'code' not in data:
data['code'] = 'unknown'
resp = self.make_response(data, status_code, headers)
else:
if 'code' not in data:
data['code'] = 'unknown'
resp = self.make_response(data, status_code, headers)
if status_code == 401:
resp = self.unauthorized(resp)
return resp
def render_root(self):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
"server_version": current_app.config['CURRENT_VERSION']
}

View File

@ -11,7 +11,7 @@ from typing import Union
from zoneinfo import available_timezones from zoneinfo import available_timezones
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import fields from flask_restx import fields
def run(script): def run(script):

View File

@ -6,6 +6,7 @@ Flask-Compress~=1.14
flask-login~=0.6.3 flask-login~=0.6.3
flask-migrate~=4.0.5 flask-migrate~=4.0.5
flask-restful~=0.3.10 flask-restful~=0.3.10
flask-restx~=1.3.0
flask-cors~=4.0.0 flask-cors~=4.0.0
gunicorn~=22.0.0 gunicorn~=22.0.0
gevent~=23.9.1 gevent~=23.9.1