mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
chore(api/controllers): Apply Ruff Formatter. (#7645)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in:
parent
7ae728a9a3
commit
13be84e4d4
|
@ -1,3 +1 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
bp = Blueprint('console', __name__, url_prefix='/console/api')
|
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
# Import other controllers
|
# Import other controllers
|
||||||
|
|
|
@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
|
||||||
def admin_required(view):
|
def admin_required(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not os.getenv('ADMIN_API_KEY'):
|
if not os.getenv("ADMIN_API_KEY"):
|
||||||
raise Unauthorized('API key is invalid.')
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
auth_header = request.headers.get('Authorization')
|
auth_header = request.headers.get("Authorization")
|
||||||
if auth_header is None:
|
if auth_header is None:
|
||||||
raise Unauthorized('Authorization header is missing.')
|
raise Unauthorized("Authorization header is missing.")
|
||||||
|
|
||||||
if ' ' not in auth_header:
|
if " " not in auth_header:
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
auth_scheme = auth_scheme.lower()
|
auth_scheme = auth_scheme.lower()
|
||||||
|
|
||||||
if auth_scheme != 'bearer':
|
if auth_scheme != "bearer":
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
if os.getenv('ADMIN_API_KEY') != auth_token:
|
if os.getenv("ADMIN_API_KEY") != auth_token:
|
||||||
raise Unauthorized('API key is invalid.')
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource):
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('desc', type=str, location='json')
|
parser.add_argument("desc", type=str, location="json")
|
||||||
parser.add_argument('copyright', type=str, location='json')
|
parser.add_argument("copyright", type=str, location="json")
|
||||||
parser.add_argument('privacy_policy', type=str, location='json')
|
parser.add_argument("privacy_policy", type=str, location="json")
|
||||||
parser.add_argument('custom_disclaimer', type=str, location='json')
|
parser.add_argument("custom_disclaimer", type=str, location="json")
|
||||||
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
|
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
|
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = App.query.filter(App.id == args['app_id']).first()
|
app = App.query.filter(App.id == args["app_id"]).first()
|
||||||
if not app:
|
if not app:
|
||||||
raise NotFound(f'App \'{args["app_id"]}\' is not found')
|
raise NotFound(f'App \'{args["app_id"]}\' is not found')
|
||||||
|
|
||||||
site = app.site
|
site = app.site
|
||||||
if not site:
|
if not site:
|
||||||
desc = args['desc'] if args['desc'] else ''
|
desc = args["desc"] if args["desc"] else ""
|
||||||
copy_right = args['copyright'] if args['copyright'] else ''
|
copy_right = args["copyright"] if args["copyright"] else ""
|
||||||
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
|
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
|
||||||
custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
|
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
|
||||||
else:
|
else:
|
||||||
desc = site.description if site.description else \
|
desc = site.description if site.description else args["desc"] if args["desc"] else ""
|
||||||
args['desc'] if args['desc'] else ''
|
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
|
||||||
copy_right = site.copyright if site.copyright else \
|
privacy_policy = (
|
||||||
args['copyright'] if args['copyright'] else ''
|
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
|
||||||
privacy_policy = site.privacy_policy if site.privacy_policy else \
|
)
|
||||||
args['privacy_policy'] if args['privacy_policy'] else ''
|
custom_disclaimer = (
|
||||||
custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
|
site.custom_disclaimer
|
||||||
args['custom_disclaimer'] if args['custom_disclaimer'] else ''
|
if site.custom_disclaimer
|
||||||
|
else args["custom_disclaimer"]
|
||||||
|
if args["custom_disclaimer"]
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
|
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||||
|
|
||||||
if not recommended_app:
|
if not recommended_app:
|
||||||
recommended_app = RecommendedApp(
|
recommended_app = RecommendedApp(
|
||||||
|
@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource):
|
||||||
copyright=copy_right,
|
copyright=copy_right,
|
||||||
privacy_policy=privacy_policy,
|
privacy_policy=privacy_policy,
|
||||||
custom_disclaimer=custom_disclaimer,
|
custom_disclaimer=custom_disclaimer,
|
||||||
language=args['language'],
|
language=args["language"],
|
||||||
category=args['category'],
|
category=args["category"],
|
||||||
position=args['position']
|
position=args["position"],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(recommended_app)
|
db.session.add(recommended_app)
|
||||||
|
@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource):
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 201
|
return {"result": "success"}, 201
|
||||||
else:
|
else:
|
||||||
recommended_app.description = desc
|
recommended_app.description = desc
|
||||||
recommended_app.copyright = copy_right
|
recommended_app.copyright = copy_right
|
||||||
recommended_app.privacy_policy = privacy_policy
|
recommended_app.privacy_policy = privacy_policy
|
||||||
recommended_app.custom_disclaimer = custom_disclaimer
|
recommended_app.custom_disclaimer = custom_disclaimer
|
||||||
recommended_app.language = args['language']
|
recommended_app.language = args["language"]
|
||||||
recommended_app.category = args['category']
|
recommended_app.category = args["category"]
|
||||||
recommended_app.position = args['position']
|
recommended_app.position = args["position"]
|
||||||
|
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class InsertExploreAppApi(Resource):
|
class InsertExploreAppApi(Resource):
|
||||||
|
@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource):
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
||||||
if not recommended_app:
|
if not recommended_app:
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
app = App.query.filter(App.id == recommended_app.app_id).first()
|
app = App.query.filter(App.id == recommended_app.app_id).first()
|
||||||
if app:
|
if app:
|
||||||
app.is_public = False
|
app.is_public = False
|
||||||
|
|
||||||
installed_apps = InstalledApp.query.filter(
|
installed_apps = InstalledApp.query.filter(
|
||||||
InstalledApp.app_id == recommended_app.app_id,
|
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
for installed_app in installed_apps:
|
for installed_app in installed_apps:
|
||||||
|
@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource):
|
||||||
db.session.delete(recommended_app)
|
db.session.delete(recommended_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
|
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
|
||||||
api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
|
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
|
||||||
|
|
|
@ -14,26 +14,21 @@ from .setup import setup_required
|
||||||
from .wraps import account_initialization_required
|
from .wraps import account_initialization_required
|
||||||
|
|
||||||
api_key_fields = {
|
api_key_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'type': fields.String,
|
"type": fields.String,
|
||||||
'token': fields.String,
|
"token": fields.String,
|
||||||
'last_used_at': TimestampField,
|
"last_used_at": TimestampField,
|
||||||
'created_at': TimestampField
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
api_key_list = {
|
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
|
||||||
'data': fields.List(fields.Nested(api_key_fields), attribute="items")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_resource(resource_id, tenant_id, resource_model):
|
def _get_resource(resource_id, tenant_id, resource_model):
|
||||||
resource = resource_model.query.filter_by(
|
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
|
||||||
id=resource_id, tenant_id=tenant_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if resource is None:
|
if resource is None:
|
||||||
flask_restful.abort(
|
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
|
||||||
404, message=f"{resource_model.__name__} not found.")
|
|
||||||
|
|
||||||
return resource
|
return resource
|
||||||
|
|
||||||
|
@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource):
|
||||||
@marshal_with(api_key_list)
|
@marshal_with(api_key_list)
|
||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id,
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
self.resource_model)
|
keys = (
|
||||||
keys = db.session.query(ApiToken). \
|
db.session.query(ApiToken)
|
||||||
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
|
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||||
all()
|
.all()
|
||||||
|
)
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id,
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
self.resource_model)
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = db.session.query(ApiToken). \
|
current_key_count = (
|
||||||
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
|
db.session.query(ApiToken)
|
||||||
count()
|
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
if current_key_count >= self.max_keys:
|
if current_key_count >= self.max_keys:
|
||||||
flask_restful.abort(
|
flask_restful.abort(
|
||||||
400,
|
400,
|
||||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
code='max_keys_exceeded'
|
code="max_keys_exceeded",
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
|
@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource):
|
||||||
def delete(self, resource_id, api_key_id):
|
def delete(self, resource_id, api_key_id):
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id,
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
self.resource_model)
|
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
key = db.session.query(ApiToken). \
|
key = (
|
||||||
filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
|
db.session.query(ApiToken)
|
||||||
first()
|
.filter(
|
||||||
|
getattr(ApiToken, self.resource_id_field) == resource_id,
|
||||||
|
ApiToken.type == self.resource_type,
|
||||||
|
ApiToken.id == api_key_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
flask_restful.abort(404, message='API key not found')
|
flask_restful.abort(404, message="API key not found")
|
||||||
|
|
||||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
resource_type = 'app'
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = 'app_id'
|
resource_id_field = "app_id"
|
||||||
token_prefix = 'app-'
|
token_prefix = "app-"
|
||||||
|
|
||||||
|
|
||||||
class AppApiKeyResource(BaseApiKeyResource):
|
class AppApiKeyResource(BaseApiKeyResource):
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
resource_type = 'app'
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = 'app_id'
|
resource_id_field = "app_id"
|
||||||
|
|
||||||
|
|
||||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
resource_type = 'dataset'
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = 'dataset_id'
|
resource_id_field = "dataset_id"
|
||||||
token_prefix = 'ds-'
|
token_prefix = "ds-"
|
||||||
|
|
||||||
|
|
||||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||||
|
|
||||||
def after_request(self, resp):
|
def after_request(self, resp):
|
||||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
return resp
|
return resp
|
||||||
resource_type = 'dataset'
|
|
||||||
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = 'dataset_id'
|
resource_id_field = "dataset_id"
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
|
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
|
||||||
api.add_resource(AppApiKeyResource,
|
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||||
'/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
|
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
|
||||||
api.add_resource(DatasetApiKeyListResource,
|
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||||
'/datasets/<uuid:resource_id>/api-keys')
|
|
||||||
api.add_resource(DatasetApiKeyResource,
|
|
||||||
'/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
|
|
||||||
|
|
|
@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
|
||||||
|
|
||||||
|
|
||||||
class AdvancedPromptTemplateList(Resource):
|
class AdvancedPromptTemplateList(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('app_mode', type=str, required=True, location='args')
|
parser.add_argument("app_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument('model_mode', type=str, required=True, location='args')
|
parser.add_argument("model_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
|
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||||
parser.add_argument('model_name', type=str, required=True, location='args')
|
parser.add_argument("model_name", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return AdvancedPromptTemplateService.get_prompt(args)
|
return AdvancedPromptTemplateService.get_prompt(args)
|
||||||
|
|
||||||
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
|
|
||||||
|
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")
|
||||||
|
|
|
@ -18,15 +18,12 @@ class AgentLogApi(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get agent logs"""
|
"""Get agent logs"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=uuid_value, required=True, location='args')
|
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
|
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return AgentService.get_agent_logs(
|
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||||
app_model,
|
|
||||||
args['conversation_id'],
|
|
||||||
args['message_id']
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
|
|
||||||
|
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")
|
||||||
|
|
|
@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def post(self, app_id, action):
|
def post(self, app_id, action):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
|
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||||
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
|
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if action == 'enable':
|
if action == "enable":
|
||||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||||
elif action == 'disable':
|
elif action == "disable":
|
||||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unsupported annotation reply action')
|
raise ValueError("Unsupported annotation reply action")
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||||
annotation_setting_id = str(annotation_setting_id)
|
annotation_setting_id = str(annotation_setting_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||||
|
@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def get(self, app_id, job_id, action):
|
def get(self, app_id, job_id, action):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
|
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||||
cache_result = redis_client.get(app_annotation_job_key)
|
cache_result = redis_client.get(app_annotation_job_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job is not exist.")
|
||||||
|
|
||||||
job_status = cache_result.decode()
|
job_status = cache_result.decode()
|
||||||
error_msg = ''
|
error_msg = ""
|
||||||
if job_status == 'error':
|
if job_status == "error":
|
||||||
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
|
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
|
||||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||||
|
|
||||||
return {
|
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||||
'job_id': job_id,
|
|
||||||
'job_status': job_status,
|
|
||||||
'error_msg': error_msg
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationListApi(Resource):
|
class AnnotationListApi(Resource):
|
||||||
|
@ -109,18 +105,18 @@ class AnnotationListApi(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
keyword = request.args.get('keyword', default=None, type=str)
|
keyword = request.args.get("keyword", default=None, type=str)
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||||
response = {
|
response = {
|
||||||
'data': marshal(annotation_list, annotation_fields),
|
"data": marshal(annotation_list, annotation_fields),
|
||||||
'has_more': len(annotation_list) == limit,
|
"has_more": len(annotation_list) == limit,
|
||||||
'limit': limit,
|
"limit": limit,
|
||||||
'total': total,
|
"total": total,
|
||||||
'page': page
|
"page": page,
|
||||||
}
|
}
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
@ -135,9 +131,7 @@ class AnnotationExportApi(Resource):
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||||
response = {
|
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||||
'data': marshal(annotation_list, annotation_fields)
|
|
||||||
}
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
|
@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource):
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('question', required=True, type=str, location='json')
|
parser.add_argument("question", required=True, type=str, location="json")
|
||||||
parser.add_argument('answer', required=True, type=str, location='json')
|
parser.add_argument("answer", required=True, type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||||
return annotation
|
return annotation
|
||||||
|
@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
|
@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('question', required=True, type=str, location='json')
|
parser.add_argument("question", required=True, type=str, location="json")
|
||||||
parser.add_argument('answer', required=True, type=str, location='json')
|
parser.add_argument("answer", required=True, type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||||
return annotation
|
return annotation
|
||||||
|
@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class AnnotationBatchImportApi(Resource):
|
class AnnotationBatchImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
# check file type
|
# check file type
|
||||||
if not file.filename.endswith('.csv'):
|
if not file.filename.endswith(".csv"):
|
||||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||||
|
|
||||||
|
@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
def get(self, app_id, job_id):
|
def get(self, app_id, job_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job is not exist.")
|
||||||
job_status = cache_result.decode()
|
job_status = cache_result.decode()
|
||||||
error_msg = ''
|
error_msg = ""
|
||||||
if job_status == 'error':
|
if job_status == "error":
|
||||||
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
|
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
|
||||||
error_msg = redis_client.get(indexing_error_msg_key).decode()
|
error_msg = redis_client.get(indexing_error_msg_key).decode()
|
||||||
|
|
||||||
return {
|
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||||
'job_id': job_id,
|
|
||||||
'job_status': job_status,
|
|
||||||
'error_msg': error_msg
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationHitHistoryListApi(Resource):
|
class AnnotationHitHistoryListApi(Resource):
|
||||||
|
@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
|
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||||
page, limit)
|
app_id, annotation_id, page, limit
|
||||||
|
)
|
||||||
response = {
|
response = {
|
||||||
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||||
'has_more': len(annotation_hit_history_list) == limit,
|
"has_more": len(annotation_hit_history_list) == limit,
|
||||||
'limit': limit,
|
"limit": limit,
|
||||||
'total': total,
|
"total": total,
|
||||||
'page': page
|
"page": page,
|
||||||
}
|
}
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
|
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||||
api.add_resource(AnnotationReplyActionStatusApi,
|
api.add_resource(
|
||||||
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
|
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
|
||||||
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
|
)
|
||||||
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
|
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
|
||||||
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
|
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
|
||||||
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
|
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||||
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
|
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
|
||||||
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
|
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||||
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
|
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||||
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
|
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
|
||||||
|
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||||
|
|
|
@ -18,27 +18,35 @@ from libs.login import login_required
|
||||||
from services.app_dsl_service import AppDslService
|
from services.app_dsl_service import AppDslService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||||
|
|
||||||
|
|
||||||
class AppListApi(Resource):
|
class AppListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get app list"""
|
"""Get app list"""
|
||||||
|
|
||||||
def uuid_list(value):
|
def uuid_list(value):
|
||||||
try:
|
try:
|
||||||
return [str(uuid.UUID(v)) for v in value.split(',')]
|
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
abort(400, message="Invalid UUID format in tag_ids.")
|
abort(400, message="Invalid UUID format in tag_ids.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
|
parser.add_argument(
|
||||||
parser.add_argument('name', type=str, location='args', required=False)
|
"mode",
|
||||||
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
|
type=str,
|
||||||
|
choices=["chat", "workflow", "agent-chat", "channel", "all"],
|
||||||
|
default="all",
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument("name", type=str, location="args", required=False)
|
||||||
|
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -46,7 +54,7 @@ class AppListApi(Resource):
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
return marshal(app_pagination, app_pagination_fields)
|
return marshal(app_pagination, app_pagination_fields)
|
||||||
|
|
||||||
|
@ -54,23 +62,23 @@ class AppListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
@cloud_edition_billing_resource_check('apps')
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Create app"""
|
"""Create app"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument('description', type=str, location='json')
|
parser.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
|
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||||
parser.add_argument('icon_type', type=str, location='json')
|
parser.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if 'mode' not in args or args['mode'] is None:
|
if "mode" not in args or args["mode"] is None:
|
||||||
raise BadRequest("mode is required")
|
raise BadRequest("mode is required")
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
|
@ -84,7 +92,7 @@ class AppImportApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
@cloud_edition_billing_resource_check('apps')
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Import app"""
|
"""Import app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
|
@ -92,19 +100,16 @@ class AppImportApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('name', type=str, location='json')
|
parser.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument('description', type=str, location='json')
|
parser.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument('icon_type', type=str, location='json')
|
parser.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = AppDslService.import_and_create_new_app(
|
app = AppDslService.import_and_create_new_app(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
|
||||||
data=args['data'],
|
|
||||||
args=args,
|
|
||||||
account=current_user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
@cloud_edition_billing_resource_check('apps')
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Import app from url"""
|
"""Import app from url"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
|
@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('name', type=str, location='json')
|
parser.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument('description', type=str, location='json')
|
parser.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = AppDslService.import_and_create_new_app_from_url(
|
app = AppDslService.import_and_create_new_app_from_url(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
|
||||||
url=args['url'],
|
|
||||||
args=args,
|
|
||||||
account=current_user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
|
||||||
|
|
||||||
class AppApi(Resource):
|
class AppApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -167,12 +168,12 @@ class AppApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('description', type=str, location='json')
|
parser.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument('icon_type', type=str, location='json')
|
parser.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
parser.add_argument('max_active_requests', type=int, location='json')
|
parser.add_argument("max_active_requests", type=int, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
|
@ -193,7 +194,7 @@ class AppApi(Resource):
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_service.delete_app(app_model)
|
app_service.delete_app(app_model)
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class AppCopyApi(Resource):
|
class AppCopyApi(Resource):
|
||||||
|
@ -209,19 +210,16 @@ class AppCopyApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, location='json')
|
parser.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument('description', type=str, location='json')
|
parser.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument('icon_type', type=str, location='json')
|
parser.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
|
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
|
||||||
app = AppDslService.import_and_create_new_app(
|
app = AppDslService.import_and_create_new_app(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
|
||||||
data=data,
|
|
||||||
args=args,
|
|
||||||
account=current_user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
@ -240,12 +238,10 @@ class AppExportApi(Resource):
|
||||||
|
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
|
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||||
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AppNameApi(Resource):
|
class AppNameApi(Resource):
|
||||||
|
@ -260,11 +256,11 @@ class AppNameApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_name(app_model, args.get('name'))
|
app_model = app_service.update_app_name(app_model, args.get("name"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@ -281,12 +277,12 @@ class AppIconApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('icon', type=str, location='json')
|
parser.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
|
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@ -303,11 +299,11 @@ class AppSiteStatus(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('enable_site', type=bool, required=True, location='json')
|
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
|
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@ -324,11 +320,11 @@ class AppApiStatus(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('enable_api', type=bool, required=True, location='json')
|
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
|
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@ -339,9 +335,7 @@ class AppTraceApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
"""Get app trace"""
|
"""Get app trace"""
|
||||||
app_trace_config = OpsTraceManager.get_app_tracing_config(
|
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
|
||||||
app_id=app_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return app_trace_config
|
return app_trace_config
|
||||||
|
|
||||||
|
@ -353,27 +347,27 @@ class AppTraceApi(Resource):
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('enabled', type=bool, required=True, location='json')
|
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
OpsTraceManager.update_app_tracing_config(
|
OpsTraceManager.update_app_tracing_config(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
enabled=args['enabled'],
|
enabled=args["enabled"],
|
||||||
tracing_provider=args['tracing_provider'],
|
tracing_provider=args["tracing_provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppListApi, '/apps')
|
api.add_resource(AppListApi, "/apps")
|
||||||
api.add_resource(AppImportApi, '/apps/import')
|
api.add_resource(AppImportApi, "/apps/import")
|
||||||
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
|
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
|
||||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
api.add_resource(AppApi, "/apps/<uuid:app_id>")
|
||||||
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
|
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
|
||||||
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
|
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
|
||||||
api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
|
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
|
||||||
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
|
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
|
||||||
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
|
||||||
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
|
||||||
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
|
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
|
||||||
|
|
|
@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AudioService.transcript_asr(
|
response = AudioService.transcript_asr(
|
||||||
|
@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, location='json')
|
parser.add_argument("message_id", type=str, location="json")
|
||||||
parser.add_argument('text', type=str, location='json')
|
parser.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument("streaming", type=bool, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id', None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get('text', None)
|
text = args.get("text", None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (
|
||||||
|
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict
|
||||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
):
|
||||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
|
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
voice = (
|
||||||
'voice')
|
args.get("voice")
|
||||||
|
if args.get("voice")
|
||||||
|
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||||
app_model=app_model,
|
|
||||||
text=text,
|
|
||||||
message_id=message_id,
|
|
||||||
voice=voice
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
logging.exception("App model config broken.")
|
logging.exception("App model config broken.")
|
||||||
|
@ -145,12 +145,12 @@ class TextModesApi(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('language', type=str, required=True, location='args')
|
parser.add_argument("language", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
response = AudioService.transcript_tts_voices(
|
response = AudioService.transcript_tts_voices(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
language=args['language'],
|
language=args["language"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -179,6 +179,6 @@ class TextModesApi(Resource):
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
|
||||||
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
|
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
|
||||||
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
|
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")
|
||||||
|
|
|
@ -35,33 +35,28 @@ from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
# define completion message api for user
|
# define completion message api for user
|
||||||
class CompletionMessageApi(Resource):
|
class CompletionMessageApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, location='json', default='')
|
parser.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] != 'blocking'
|
streaming = args["response_mode"] != "blocking"
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
account = flask_login.current_user
|
account = flask_login.current_user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
user=account,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -97,7 +92,7 @@ class CompletionMessageStopApi(Resource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageApi(Resource):
|
class ChatMessageApi(Resource):
|
||||||
|
@ -107,27 +102,23 @@ class ChatMessageApi(Resource):
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, required=True, location='json')
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] != 'blocking'
|
streaming = args["response_mode"] != "blocking"
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
account = flask_login.current_user
|
account = flask_login.current_user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||||
user=account,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -163,10 +154,10 @@ class ChatMessageStopApi(Resource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
|
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
|
||||||
api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
|
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||||
api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
|
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
|
||||||
api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
|
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||||
|
|
|
@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationApi(Resource):
|
class CompletionConversationApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -36,24 +35,23 @@ class CompletionConversationApi(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('keyword', type=str, location='args')
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('annotation_status', type=str,
|
parser.add_argument(
|
||||||
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
|
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||||
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
|
)
|
||||||
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
|
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
|
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
|
||||||
|
|
||||||
if args['keyword']:
|
if args["keyword"]:
|
||||||
query = query.join(
|
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
|
||||||
Message, Message.conversation_id == Conversation.id
|
|
||||||
).filter(
|
|
||||||
or_(
|
or_(
|
||||||
Message.query.ilike('%{}%'.format(args['keyword'])),
|
Message.query.ilike("%{}%".format(args["keyword"])),
|
||||||
Message.answer.ilike('%{}%'.format(args['keyword']))
|
Message.answer.ilike("%{}%".format(args["keyword"])),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,8 +59,8 @@ class CompletionConversationApi(Resource):
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
|
@ -70,8 +68,8 @@ class CompletionConversationApi(Resource):
|
||||||
|
|
||||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=59)
|
end_datetime = end_datetime.replace(second=59)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
|
@ -79,29 +77,25 @@ class CompletionConversationApi(Resource):
|
||||||
|
|
||||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||||
|
|
||||||
if args['annotation_status'] == "annotated":
|
if args["annotation_status"] == "annotated":
|
||||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||||
)
|
)
|
||||||
elif args['annotation_status'] == "not_annotated":
|
elif args["annotation_status"] == "not_annotated":
|
||||||
query = query.outerjoin(
|
query = (
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||||
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
|
.group_by(Conversation.id)
|
||||||
|
.having(func.count(MessageAnnotation.id) == 0)
|
||||||
|
)
|
||||||
|
|
||||||
query = query.order_by(Conversation.created_at.desc())
|
query = query.order_by(Conversation.created_at.desc())
|
||||||
|
|
||||||
conversations = db.paginate(
|
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||||
query,
|
|
||||||
page=args['page'],
|
|
||||||
per_page=args['limit'],
|
|
||||||
error_out=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationDetailApi(Resource):
|
class CompletionConversationDetailApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
conversation = db.session.query(Conversation) \
|
conversation = (
|
||||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource):
|
||||||
conversation.is_deleted = True
|
conversation.is_deleted = True
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class ChatConversationApi(Resource):
|
class ChatConversationApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -146,22 +142,28 @@ class ChatConversationApi(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('keyword', type=str, location='args')
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('annotation_status', type=str,
|
parser.add_argument(
|
||||||
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
|
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||||
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
|
)
|
||||||
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
|
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
required=False, default='-updated_at', location='args')
|
parser.add_argument(
|
||||||
|
"sort_by",
|
||||||
|
type=str,
|
||||||
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
|
required=False,
|
||||||
|
default="-updated_at",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
subquery = (
|
subquery = (
|
||||||
db.session.query(
|
db.session.query(
|
||||||
Conversation.id.label('conversation_id'),
|
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||||
EndUser.session_id.label('from_end_user_session_id')
|
|
||||||
)
|
)
|
||||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||||
.subquery()
|
.subquery()
|
||||||
|
@ -169,28 +171,31 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
|
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
|
||||||
|
|
||||||
if args['keyword']:
|
if args["keyword"]:
|
||||||
keyword_filter = '%{}%'.format(args['keyword'])
|
keyword_filter = "%{}%".format(args["keyword"])
|
||||||
query = query.join(
|
query = (
|
||||||
Message, Message.conversation_id == Conversation.id,
|
query.join(
|
||||||
).join(
|
Message,
|
||||||
subquery, subquery.c.conversation_id == Conversation.id
|
Message.conversation_id == Conversation.id,
|
||||||
).filter(
|
)
|
||||||
|
.join(subquery, subquery.c.conversation_id == Conversation.id)
|
||||||
|
.filter(
|
||||||
or_(
|
or_(
|
||||||
Message.query.ilike(keyword_filter),
|
Message.query.ilike(keyword_filter),
|
||||||
Message.answer.ilike(keyword_filter),
|
Message.answer.ilike(keyword_filter),
|
||||||
Conversation.name.ilike(keyword_filter),
|
Conversation.name.ilike(keyword_filter),
|
||||||
Conversation.introduction.ilike(keyword_filter),
|
Conversation.introduction.ilike(keyword_filter),
|
||||||
subquery.c.from_end_user_session_id.ilike(keyword_filter)
|
subquery.c.from_end_user_session_id.ilike(keyword_filter),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
|
@ -198,8 +203,8 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=59)
|
end_datetime = end_datetime.replace(second=59)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
|
@ -207,50 +212,46 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||||
|
|
||||||
if args['annotation_status'] == "annotated":
|
if args["annotation_status"] == "annotated":
|
||||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||||
)
|
)
|
||||||
elif args['annotation_status'] == "not_annotated":
|
elif args["annotation_status"] == "not_annotated":
|
||||||
query = query.outerjoin(
|
query = (
|
||||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||||
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
|
.group_by(Conversation.id)
|
||||||
|
.having(func.count(MessageAnnotation.id) == 0)
|
||||||
|
)
|
||||||
|
|
||||||
if args['message_count_gte'] and args['message_count_gte'] >= 1:
|
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||||
query = (
|
query = (
|
||||||
query.options(joinedload(Conversation.messages))
|
query.options(joinedload(Conversation.messages))
|
||||||
.join(Message, Message.conversation_id == Conversation.id)
|
.join(Message, Message.conversation_id == Conversation.id)
|
||||||
.group_by(Conversation.id)
|
.group_by(Conversation.id)
|
||||||
.having(func.count(Message.id) >= args['message_count_gte'])
|
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||||
)
|
)
|
||||||
|
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||||
|
|
||||||
match args['sort_by']:
|
match args["sort_by"]:
|
||||||
case 'created_at':
|
case "created_at":
|
||||||
query = query.order_by(Conversation.created_at.asc())
|
query = query.order_by(Conversation.created_at.asc())
|
||||||
case '-created_at':
|
case "-created_at":
|
||||||
query = query.order_by(Conversation.created_at.desc())
|
query = query.order_by(Conversation.created_at.desc())
|
||||||
case 'updated_at':
|
case "updated_at":
|
||||||
query = query.order_by(Conversation.updated_at.asc())
|
query = query.order_by(Conversation.updated_at.asc())
|
||||||
case '-updated_at':
|
case "-updated_at":
|
||||||
query = query.order_by(Conversation.updated_at.desc())
|
query = query.order_by(Conversation.updated_at.desc())
|
||||||
case _:
|
case _:
|
||||||
query = query.order_by(Conversation.created_at.desc())
|
query = query.order_by(Conversation.created_at.desc())
|
||||||
|
|
||||||
conversations = db.paginate(
|
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||||
query,
|
|
||||||
page=args['page'],
|
|
||||||
per_page=args['limit'],
|
|
||||||
error_out=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
|
|
||||||
class ChatConversationDetailApi(Resource):
|
class ChatConversationDetailApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
conversation = db.session.query(Conversation) \
|
conversation = (
|
||||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource):
|
||||||
conversation.is_deleted = True
|
conversation.is_deleted = True
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
|
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
|
||||||
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
|
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||||
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
|
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
|
||||||
api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
|
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||||
|
|
||||||
|
|
||||||
def _get_conversation(app_model, conversation_id):
|
def _get_conversation(app_model, conversation_id):
|
||||||
conversation = db.session.query(Conversation) \
|
conversation = (
|
||||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
|
@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
|
||||||
@marshal_with(paginated_conversation_variable_fields)
|
@marshal_with(paginated_conversation_variable_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('conversation_id', type=str, location='args')
|
parser.add_argument("conversation_id", type=str, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
|
@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
|
||||||
.where(ConversationVariable.app_id == app_model.id)
|
.where(ConversationVariable.app_id == app_model.id)
|
||||||
.order_by(ConversationVariable.created_at)
|
.order_by(ConversationVariable.created_at)
|
||||||
)
|
)
|
||||||
if args['conversation_id']:
|
if args["conversation_id"]:
|
||||||
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
|
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||||
else:
|
else:
|
||||||
raise ValueError('conversation_id is required')
|
raise ValueError("conversation_id is required")
|
||||||
|
|
||||||
# NOTE: This is a temporary solution to avoid performance issues.
|
# NOTE: This is a temporary solution to avoid performance issues.
|
||||||
page = 1
|
page = 1
|
||||||
|
@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
|
||||||
rows = session.scalars(stmt).all()
|
rows = session.scalars(stmt).all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'page': page,
|
"page": page,
|
||||||
'limit': page_size,
|
"limit": page_size,
|
||||||
'total': len(rows),
|
"total": len(rows),
|
||||||
'has_more': False,
|
"has_more": False,
|
||||||
'data': [
|
"data": [
|
||||||
{
|
{
|
||||||
'created_at': row.created_at,
|
"created_at": row.created_at,
|
||||||
'updated_at': row.updated_at,
|
"updated_at": row.updated_at,
|
||||||
**row.to_variable().model_dump(),
|
**row.to_variable().model_dump(),
|
||||||
}
|
}
|
||||||
for row in rows
|
for row in rows
|
||||||
|
@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
|
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")
|
||||||
|
|
|
@ -2,116 +2,120 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class AppNotFoundError(BaseHTTPException):
|
class AppNotFoundError(BaseHTTPException):
|
||||||
error_code = 'app_not_found'
|
error_code = "app_not_found"
|
||||||
description = "App not found."
|
description = "App not found."
|
||||||
code = 404
|
code = 404
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotInitializeError(BaseHTTPException):
|
class ProviderNotInitializeError(BaseHTTPException):
|
||||||
error_code = 'provider_not_initialize'
|
error_code = "provider_not_initialize"
|
||||||
description = "No valid model provider credentials found. " \
|
description = (
|
||||||
|
"No valid model provider credentials found. "
|
||||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderQuotaExceededError(BaseHTTPException):
|
class ProviderQuotaExceededError(BaseHTTPException):
|
||||||
error_code = 'provider_quota_exceeded'
|
error_code = "provider_quota_exceeded"
|
||||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
description = (
|
||||||
|
"Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
||||||
error_code = 'model_currently_not_support'
|
error_code = "model_currently_not_support"
|
||||||
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ConversationCompletedError(BaseHTTPException):
|
class ConversationCompletedError(BaseHTTPException):
|
||||||
error_code = 'conversation_completed'
|
error_code = "conversation_completed"
|
||||||
description = "The conversation has ended. Please start a new conversation."
|
description = "The conversation has ended. Please start a new conversation."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AppUnavailableError(BaseHTTPException):
|
class AppUnavailableError(BaseHTTPException):
|
||||||
error_code = 'app_unavailable'
|
error_code = "app_unavailable"
|
||||||
description = "App unavailable, please check your app configurations."
|
description = "App unavailable, please check your app configurations."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestError(BaseHTTPException):
|
class CompletionRequestError(BaseHTTPException):
|
||||||
error_code = 'completion_request_error'
|
error_code = "completion_request_error"
|
||||||
description = "Completion request failed."
|
description = "Completion request failed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AppMoreLikeThisDisabledError(BaseHTTPException):
|
class AppMoreLikeThisDisabledError(BaseHTTPException):
|
||||||
error_code = 'app_more_like_this_disabled'
|
error_code = "app_more_like_this_disabled"
|
||||||
description = "The 'More like this' feature is disabled. Please refresh your page."
|
description = "The 'More like this' feature is disabled. Please refresh your page."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class NoAudioUploadedError(BaseHTTPException):
|
class NoAudioUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_audio_uploaded'
|
error_code = "no_audio_uploaded"
|
||||||
description = "Please upload your audio."
|
description = "Please upload your audio."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AudioTooLargeError(BaseHTTPException):
|
class AudioTooLargeError(BaseHTTPException):
|
||||||
error_code = 'audio_too_large'
|
error_code = "audio_too_large"
|
||||||
description = "Audio size exceeded. {message}"
|
description = "Audio size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_audio_type'
|
error_code = "unsupported_audio_type"
|
||||||
description = "Audio type not allowed."
|
description = "Audio type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||||
error_code = 'provider_not_support_speech_to_text'
|
error_code = "provider_not_support_speech_to_text"
|
||||||
description = "Provider not support speech to text."
|
description = "Provider not support speech to text."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NoFileUploadedError(BaseHTTPException):
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_file_uploaded'
|
error_code = "no_file_uploaded"
|
||||||
description = "Please upload your file."
|
description = "Please upload your file."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = 'too_many_files'
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DraftWorkflowNotExist(BaseHTTPException):
|
class DraftWorkflowNotExist(BaseHTTPException):
|
||||||
error_code = 'draft_workflow_not_exist'
|
error_code = "draft_workflow_not_exist"
|
||||||
description = "Draft workflow need to be initialized."
|
description = "Draft workflow need to be initialized."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DraftWorkflowNotSync(BaseHTTPException):
|
class DraftWorkflowNotSync(BaseHTTPException):
|
||||||
error_code = 'draft_workflow_not_sync'
|
error_code = "draft_workflow_not_sync"
|
||||||
description = "Workflow graph might have been modified, please refresh and resubmit."
|
description = "Workflow graph might have been modified, please refresh and resubmit."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TracingConfigNotExist(BaseHTTPException):
|
class TracingConfigNotExist(BaseHTTPException):
|
||||||
error_code = 'trace_config_not_exist'
|
error_code = "trace_config_not_exist"
|
||||||
description = "Trace config not exist."
|
description = "Trace config not exist."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TracingConfigIsExist(BaseHTTPException):
|
class TracingConfigIsExist(BaseHTTPException):
|
||||||
error_code = 'trace_config_is_exist'
|
error_code = "trace_config_is_exist"
|
||||||
description = "Trace config is exist."
|
description = "Trace config is exist."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TracingConfigCheckError(BaseHTTPException):
|
class TracingConfigCheckError(BaseHTTPException):
|
||||||
error_code = 'trace_config_check_error'
|
error_code = "trace_config_check_error"
|
||||||
description = "Invalid Credentials."
|
description = "Invalid Credentials."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
|
@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
|
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
|
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rules = LLMGenerator.generate_rule_config(
|
rules = LLMGenerator.generate_rule_config(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=account.current_tenant_id,
|
||||||
instruction=args['instruction'],
|
instruction=args["instruction"],
|
||||||
model_config=args['model_config'],
|
model_config=args["model_config"],
|
||||||
no_variable=args['no_variable'],
|
no_variable=args["no_variable"],
|
||||||
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
|
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||||
|
|
|
@ -33,9 +33,9 @@ from services.message_service import MessageService
|
||||||
|
|
||||||
class ChatMessageListApi(Resource):
|
class ChatMessageListApi(Resource):
|
||||||
message_infinite_scroll_pagination_fields = {
|
message_infinite_scroll_pagination_fields = {
|
||||||
'limit': fields.Integer,
|
"limit": fields.Integer,
|
||||||
'has_more': fields.Boolean,
|
"has_more": fields.Boolean,
|
||||||
'data': fields.List(fields.Nested(message_detail_fields))
|
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -45,55 +45,69 @@ class ChatMessageListApi(Resource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
conversation = db.session.query(Conversation).filter(
|
conversation = (
|
||||||
Conversation.id == args['conversation_id'],
|
db.session.query(Conversation)
|
||||||
Conversation.app_id == app_model.id
|
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||||
).first()
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
if args['first_id']:
|
if args["first_id"]:
|
||||||
first_message = db.session.query(Message) \
|
first_message = (
|
||||||
.filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
|
db.session.query(Message)
|
||||||
|
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not first_message:
|
if not first_message:
|
||||||
raise NotFound("First message not found")
|
raise NotFound("First message not found")
|
||||||
|
|
||||||
history_messages = db.session.query(Message).filter(
|
history_messages = (
|
||||||
|
db.session.query(Message)
|
||||||
|
.filter(
|
||||||
Message.conversation_id == conversation.id,
|
Message.conversation_id == conversation.id,
|
||||||
Message.created_at < first_message.created_at,
|
Message.created_at < first_message.created_at,
|
||||||
Message.id != first_message.id
|
Message.id != first_message.id,
|
||||||
) \
|
)
|
||||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
.order_by(Message.created_at.desc())
|
||||||
|
.limit(args["limit"])
|
||||||
|
.all()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
history_messages = (
|
||||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
db.session.query(Message)
|
||||||
|
.filter(Message.conversation_id == conversation.id)
|
||||||
|
.order_by(Message.created_at.desc())
|
||||||
|
.limit(args["limit"])
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
has_more = False
|
has_more = False
|
||||||
if len(history_messages) == args['limit']:
|
if len(history_messages) == args["limit"]:
|
||||||
current_page_first_message = history_messages[-1]
|
current_page_first_message = history_messages[-1]
|
||||||
rest_count = db.session.query(Message).filter(
|
rest_count = (
|
||||||
|
db.session.query(Message)
|
||||||
|
.filter(
|
||||||
Message.conversation_id == conversation.id,
|
Message.conversation_id == conversation.id,
|
||||||
Message.created_at < current_page_first_message.created_at,
|
Message.created_at < current_page_first_message.created_at,
|
||||||
Message.id != current_page_first_message.id
|
Message.id != current_page_first_message.id,
|
||||||
).count()
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
if rest_count > 0:
|
if rest_count > 0:
|
||||||
has_more = True
|
has_more = True
|
||||||
|
|
||||||
history_messages = list(reversed(history_messages))
|
history_messages = list(reversed(history_messages))
|
||||||
|
|
||||||
return InfiniteScrollPagination(
|
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||||
data=history_messages,
|
|
||||||
limit=args['limit'],
|
|
||||||
has_more=has_more
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackApi(Resource):
|
class MessageFeedbackApi(Resource):
|
||||||
|
@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
|
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = str(args['message_id'])
|
message_id = str(args["message_id"])
|
||||||
|
|
||||||
message = db.session.query(Message).filter(
|
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||||
Message.id == message_id,
|
|
||||||
Message.app_id == app_model.id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
feedback = message.admin_feedback
|
feedback = message.admin_feedback
|
||||||
|
|
||||||
if not args['rating'] and feedback:
|
if not args["rating"] and feedback:
|
||||||
db.session.delete(feedback)
|
db.session.delete(feedback)
|
||||||
elif args['rating'] and feedback:
|
elif args["rating"] and feedback:
|
||||||
feedback.rating = args['rating']
|
feedback.rating = args["rating"]
|
||||||
elif not args['rating'] and not feedback:
|
elif not args["rating"] and not feedback:
|
||||||
raise ValueError('rating cannot be None when feedback not exists')
|
raise ValueError("rating cannot be None when feedback not exists")
|
||||||
else:
|
else:
|
||||||
feedback = MessageFeedback(
|
feedback = MessageFeedback(
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
conversation_id=message.conversation_id,
|
conversation_id=message.conversation_id,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
rating=args['rating'],
|
rating=args["rating"],
|
||||||
from_source='admin',
|
from_source="admin",
|
||||||
from_account_id=current_user.id
|
from_account_id=current_user.id,
|
||||||
)
|
)
|
||||||
db.session.add(feedback)
|
db.session.add(feedback)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class MessageAnnotationApi(Resource):
|
class MessageAnnotationApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('annotation')
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
|
@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
|
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||||
parser.add_argument('question', required=True, type=str, location='json')
|
parser.add_argument("question", required=True, type=str, location="json")
|
||||||
parser.add_argument('answer', required=True, type=str, location='json')
|
parser.add_argument("answer", required=True, type=str, location="json")
|
||||||
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
|
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||||
|
|
||||||
|
@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
count = db.session.query(MessageAnnotation).filter(
|
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
|
||||||
MessageAnnotation.app_id == app_model.id
|
|
||||||
).count()
|
|
||||||
|
|
||||||
return {'count': count}
|
return {"count": count}
|
||||||
|
|
||||||
|
|
||||||
class MessageSuggestedQuestionApi(Resource):
|
class MessageSuggestedQuestionApi(Resource):
|
||||||
|
@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model,
|
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||||
message_id=message_id,
|
|
||||||
user=current_user,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER
|
|
||||||
)
|
)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message not found")
|
raise NotFound("Message not found")
|
||||||
|
@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
return {'data': questions}
|
return {"data": questions}
|
||||||
|
|
||||||
|
|
||||||
class MessageApi(Resource):
|
class MessageApi(Resource):
|
||||||
|
@ -221,10 +227,7 @@ class MessageApi(Resource):
|
||||||
def get(self, app_model, message_id):
|
def get(self, app_model, message_id):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).filter(
|
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||||
Message.id == message_id,
|
|
||||||
Message.app_id == app_model.id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
@ -232,9 +235,9 @@ class MessageApi(Resource):
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
|
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||||
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
|
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
|
||||||
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
|
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
|
||||||
api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
|
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
|
||||||
api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
|
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
|
||||||
api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')
|
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")
|
||||||
|
|
|
@ -19,19 +19,15 @@ from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigResource(Resource):
|
class ModelConfigResource(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
|
|
||||||
"""Modify app model config"""
|
"""Modify app model config"""
|
||||||
# validate config
|
# validate config
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
|
||||||
config=request.json,
|
|
||||||
app_mode=AppMode.value_of(app_model.mode)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_app_model_config = AppModelConfig(
|
new_app_model_config = AppModelConfig(
|
||||||
|
@ -41,15 +37,15 @@ class ModelConfigResource(Resource):
|
||||||
|
|
||||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||||
# get original app model config
|
# get original app model config
|
||||||
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
|
original_app_model_config: AppModelConfig = (
|
||||||
AppModelConfig.id == app_model.app_model_config_id
|
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||||
).first()
|
)
|
||||||
agent_mode = original_app_model_config.agent_mode_dict
|
agent_mode = original_app_model_config.agent_mode_dict
|
||||||
# decrypt agent tool parameters if it's secret-input
|
# decrypt agent tool parameters if it's secret-input
|
||||||
parameter_map = {}
|
parameter_map = {}
|
||||||
masked_parameter_map = {}
|
masked_parameter_map = {}
|
||||||
tool_map = {}
|
tool_map = {}
|
||||||
for tool in agent_mode.get('tools') or []:
|
for tool in agent_mode.get("tools") or []:
|
||||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -66,7 +62,7 @@ class ModelConfigResource(Resource):
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
identity_id=f'AGENT.{app_model.id}'
|
identity_id=f"AGENT.{app_model.id}",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue
|
continue
|
||||||
|
@ -79,18 +75,18 @@ class ModelConfigResource(Resource):
|
||||||
parameters = {}
|
parameters = {}
|
||||||
masked_parameter = {}
|
masked_parameter = {}
|
||||||
|
|
||||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||||
masked_parameter_map[key] = masked_parameter
|
masked_parameter_map[key] = masked_parameter
|
||||||
parameter_map[key] = parameters
|
parameter_map[key] = parameters
|
||||||
tool_map[key] = tool_runtime
|
tool_map[key] = tool_runtime
|
||||||
|
|
||||||
# encrypt agent tool parameters if it's secret-input
|
# encrypt agent tool parameters if it's secret-input
|
||||||
agent_mode = new_app_model_config.agent_mode_dict
|
agent_mode = new_app_model_config.agent_mode_dict
|
||||||
for tool in agent_mode.get('tools') or []:
|
for tool in agent_mode.get("tools") or []:
|
||||||
agent_tool_entity = AgentToolEntity(**tool)
|
agent_tool_entity = AgentToolEntity(**tool)
|
||||||
|
|
||||||
# get tool
|
# get tool
|
||||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||||
if key in tool_map:
|
if key in tool_map:
|
||||||
tool_runtime = tool_map[key]
|
tool_runtime = tool_map[key]
|
||||||
else:
|
else:
|
||||||
|
@ -108,7 +104,7 @@ class ModelConfigResource(Resource):
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
identity_id=f'AGENT.{app_model.id}'
|
identity_id=f"AGENT.{app_model.id}",
|
||||||
)
|
)
|
||||||
manager.delete_tool_parameters_cache()
|
manager.delete_tool_parameters_cache()
|
||||||
|
|
||||||
|
@ -118,13 +114,15 @@ class ModelConfigResource(Resource):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for masked_key, masked_value in masked_parameter_map[key].items():
|
for masked_key, masked_value in masked_parameter_map[key].items():
|
||||||
if masked_key in agent_tool_entity.tool_parameters and \
|
if (
|
||||||
agent_tool_entity.tool_parameters[masked_key] == masked_value:
|
masked_key in agent_tool_entity.tool_parameters
|
||||||
|
and agent_tool_entity.tool_parameters[masked_key] == masked_value
|
||||||
|
):
|
||||||
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
||||||
|
|
||||||
# encrypt parameters
|
# encrypt parameters
|
||||||
if agent_tool_entity.tool_parameters:
|
if agent_tool_entity.tool_parameters:
|
||||||
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||||
|
|
||||||
# update app model config
|
# update app model config
|
||||||
new_app_model_config.agent_mode = json.dumps(agent_mode)
|
new_app_model_config.agent_mode = json.dumps(agent_mode)
|
||||||
|
@ -135,12 +133,9 @@ class ModelConfigResource(Resource):
|
||||||
app_model.app_model_config_id = new_app_model_config.id
|
app_model.app_model_config_id = new_app_model_config.id
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
app_model_config_was_updated.send(
|
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||||
app_model,
|
|
||||||
app_model_config=new_app_model_config
|
|
||||||
)
|
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
|
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")
|
||||||
|
|
|
@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tracing_provider', type=str, required=True, location='args')
|
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trace_config = OpsService.get_tracing_app_config(
|
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||||
app_id=app_id, tracing_provider=args['tracing_provider']
|
|
||||||
)
|
|
||||||
if not trace_config:
|
if not trace_config:
|
||||||
return {"has_not_configured": True}
|
return {"has_not_configured": True}
|
||||||
return trace_config
|
return trace_config
|
||||||
|
@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource):
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
"""Create a new trace app configuration"""
|
"""Create a new trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
parser.add_argument('tracing_config', type=dict, required=True, location='json')
|
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.create_tracing_app_config(
|
result = OpsService.create_tracing_app_config(
|
||||||
app_id=app_id,
|
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||||
tracing_provider=args['tracing_provider'],
|
|
||||||
tracing_config=args['tracing_config']
|
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigIsExist()
|
raise TracingConfigIsExist()
|
||||||
if result.get('error'):
|
if result.get("error"):
|
||||||
raise TracingConfigCheckError()
|
raise TracingConfigCheckError()
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource):
|
||||||
def patch(self, app_id):
|
def patch(self, app_id):
|
||||||
"""Update an existing trace app configuration"""
|
"""Update an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
parser.add_argument('tracing_config', type=dict, required=True, location='json')
|
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.update_tracing_app_config(
|
result = OpsService.update_tracing_app_config(
|
||||||
app_id=app_id,
|
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||||
tracing_provider=args['tracing_provider'],
|
|
||||||
tracing_config=args['tracing_config']
|
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
|
@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource):
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
"""Delete an existing trace app configuration"""
|
"""Delete an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tracing_provider', type=str, required=True, location='args')
|
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.delete_tracing_app_config(
|
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||||
app_id=app_id,
|
|
||||||
tracing_provider=args['tracing_provider']
|
|
||||||
)
|
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
|
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
|
||||||
|
|
|
@ -15,23 +15,23 @@ from models.model import Site
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('title', type=str, required=False, location='json')
|
parser.add_argument("title", type=str, required=False, location="json")
|
||||||
parser.add_argument('icon_type', type=str, required=False, location='json')
|
parser.add_argument("icon_type", type=str, required=False, location="json")
|
||||||
parser.add_argument('icon', type=str, required=False, location='json')
|
parser.add_argument("icon", type=str, required=False, location="json")
|
||||||
parser.add_argument('icon_background', type=str, required=False, location='json')
|
parser.add_argument("icon_background", type=str, required=False, location="json")
|
||||||
parser.add_argument('description', type=str, required=False, location='json')
|
parser.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument('default_language', type=supported_language, required=False, location='json')
|
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||||
parser.add_argument('chat_color_theme', type=str, required=False, location='json')
|
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||||
parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
|
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||||
parser.add_argument('customize_domain', type=str, required=False, location='json')
|
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
||||||
parser.add_argument('copyright', type=str, required=False, location='json')
|
parser.add_argument("copyright", type=str, required=False, location="json")
|
||||||
parser.add_argument('privacy_policy', type=str, required=False, location='json')
|
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||||
parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
|
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||||
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
|
parser.add_argument(
|
||||||
required=False,
|
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
||||||
location='json')
|
)
|
||||||
parser.add_argument('prompt_public', type=bool, required=False, location='json')
|
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||||
parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
|
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,26 +48,24 @@ class AppSite(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
site = db.session.query(Site). \
|
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
|
||||||
filter(Site.app_id == app_model.id). \
|
|
||||||
one_or_404()
|
|
||||||
|
|
||||||
for attr_name in [
|
for attr_name in [
|
||||||
'title',
|
"title",
|
||||||
'icon_type',
|
"icon_type",
|
||||||
'icon',
|
"icon",
|
||||||
'icon_background',
|
"icon_background",
|
||||||
'description',
|
"description",
|
||||||
'default_language',
|
"default_language",
|
||||||
'chat_color_theme',
|
"chat_color_theme",
|
||||||
'chat_color_theme_inverted',
|
"chat_color_theme_inverted",
|
||||||
'customize_domain',
|
"customize_domain",
|
||||||
'copyright',
|
"copyright",
|
||||||
'privacy_policy',
|
"privacy_policy",
|
||||||
'custom_disclaimer',
|
"custom_disclaimer",
|
||||||
'customize_token_strategy',
|
"customize_token_strategy",
|
||||||
'prompt_public',
|
"prompt_public",
|
||||||
'show_workflow_steps'
|
"show_workflow_steps",
|
||||||
]:
|
]:
|
||||||
value = args.get(attr_name)
|
value = args.get(attr_name)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
|
@ -79,7 +77,6 @@ class AppSite(Resource):
|
||||||
|
|
||||||
|
|
||||||
class AppSiteAccessTokenReset(Resource):
|
class AppSiteAccessTokenReset(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -101,5 +98,5 @@ class AppSiteAccessTokenReset(Resource):
|
||||||
return site
|
return site
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
|
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
|
||||||
api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
|
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")
|
||||||
|
|
|
@ -17,7 +17,6 @@ from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class DailyConversationStatistic(Resource):
|
class DailyConversationStatistic(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
||||||
FROM messages where app_id = :app_id
|
FROM messages where app_id = :app_id
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||||
'date': str(i.date),
|
|
||||||
'conversation_count': i.conversation_count
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class DailyTerminalsStatistic(Resource):
|
class DailyTerminalsStatistic(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -86,54 +79,49 @@ class DailyTerminalsStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
||||||
FROM messages where app_id = :app_id
|
FROM messages where app_id = :app_id
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||||
'date': str(i.date),
|
|
||||||
'terminal_count': i.terminal_count
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class DailyTokenCostStatistic(Resource):
|
class DailyTokenCostStatistic(Resource):
|
||||||
|
@ -145,58 +133,53 @@ class DailyTokenCostStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
||||||
sum(total_price) as total_price
|
sum(total_price) as total_price
|
||||||
FROM messages where app_id = :app_id
|
FROM messages where app_id = :app_id
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append(
|
||||||
'date': str(i.date),
|
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
|
||||||
'token_count': i.token_count,
|
)
|
||||||
'total_price': i.total_price,
|
|
||||||
'currency': 'USD'
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class AverageSessionInteractionStatistic(Resource):
|
class AverageSessionInteractionStatistic(Resource):
|
||||||
|
@ -208,8 +191,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
@ -218,30 +201,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||||
FROM conversations c
|
FROM conversations c
|
||||||
JOIN messages m ON c.id = m.conversation_id
|
JOIN messages m ON c.id = m.conversation_id
|
||||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and c.created_at >= :start'
|
sql_query += " and c.created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and c.created_at < :end'
|
sql_query += " and c.created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += """
|
sql_query += """
|
||||||
GROUP BY m.conversation_id) subquery
|
GROUP BY m.conversation_id) subquery
|
||||||
|
@ -254,14 +237,11 @@ ORDER BY date"""
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append(
|
||||||
'date': str(i.date),
|
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class UserSatisfactionRateStatistic(Resource):
|
class UserSatisfactionRateStatistic(Resource):
|
||||||
|
@ -273,57 +253,57 @@ class UserSatisfactionRateStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
||||||
FROM messages m
|
FROM messages m
|
||||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
||||||
WHERE m.app_id = :app_id
|
WHERE m.app_id = :app_id
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and m.created_at >= :start'
|
sql_query += " and m.created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and m.created_at < :end'
|
sql_query += " and m.created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append(
|
||||||
'date': str(i.date),
|
{
|
||||||
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
"date": str(i.date),
|
||||||
})
|
"rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class AverageResponseTimeStatistic(Resource):
|
class AverageResponseTimeStatistic(Resource):
|
||||||
|
@ -335,56 +315,51 @@ class AverageResponseTimeStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
AVG(provider_response_latency) as latency
|
AVG(provider_response_latency) as latency
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE app_id = :app_id
|
WHERE app_id = :app_id
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
|
||||||
'date': str(i.date),
|
|
||||||
'latency': round(i.latency * 1000, 4)
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondStatistic(Resource):
|
class TokensPerSecondStatistic(Resource):
|
||||||
|
@ -396,63 +371,58 @@ class TokensPerSecondStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
CASE
|
CASE
|
||||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||||
END as tokens_per_second
|
END as tokens_per_second
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE app_id = :app_id'''
|
WHERE app_id = :app_id"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
|
||||||
'date': str(i.date),
|
|
||||||
'tps': round(i.tokens_per_second, 4)
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
|
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||||
api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
|
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||||
api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
|
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
|
||||||
api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
|
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||||
api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
|
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||||
api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
|
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
|
||||||
api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
|
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||||
|
|
|
@ -65,50 +65,50 @@ class DraftWorkflowApi(Resource):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
content_type = request.headers.get('Content-Type', '')
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if 'application/json' in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('hash', type=str, required=False, location='json')
|
parser.add_argument("hash", type=str, required=False, location="json")
|
||||||
# TODO: set this to required=True after frontend is updated
|
# TODO: set this to required=True after frontend is updated
|
||||||
parser.add_argument('environment_variables', type=list, required=False, location='json')
|
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument('conversation_variables', type=list, required=False, location='json')
|
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif 'text/plain' in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
data = json.loads(request.data.decode('utf-8'))
|
data = json.loads(request.data.decode("utf-8"))
|
||||||
if 'graph' not in data or 'features' not in data:
|
if "graph" not in data or "features" not in data:
|
||||||
raise ValueError('graph or features not found in data')
|
raise ValueError("graph or features not found in data")
|
||||||
|
|
||||||
if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
|
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||||
raise ValueError('graph or features is not a dict')
|
raise ValueError("graph or features is not a dict")
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
'graph': data.get('graph'),
|
"graph": data.get("graph"),
|
||||||
'features': data.get('features'),
|
"features": data.get("features"),
|
||||||
'hash': data.get('hash'),
|
"hash": data.get("hash"),
|
||||||
'environment_variables': data.get('environment_variables'),
|
"environment_variables": data.get("environment_variables"),
|
||||||
'conversation_variables': data.get('conversation_variables'),
|
"conversation_variables": data.get("conversation_variables"),
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return {'message': 'Invalid JSON data'}, 400
|
return {"message": "Invalid JSON data"}, 400
|
||||||
else:
|
else:
|
||||||
abort(415)
|
abort(415)
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
environment_variables_list = args.get('environment_variables') or []
|
environment_variables_list = args.get("environment_variables") or []
|
||||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||||
conversation_variables_list = args.get('conversation_variables') or []
|
conversation_variables_list = args.get("conversation_variables") or []
|
||||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||||
workflow = workflow_service.sync_draft_workflow(
|
workflow = workflow_service.sync_draft_workflow(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
graph=args['graph'],
|
graph=args["graph"],
|
||||||
features=args['features'],
|
features=args["features"],
|
||||||
unique_hash=args.get('hash'),
|
unique_hash=args.get("hash"),
|
||||||
account=current_user,
|
account=current_user,
|
||||||
environment_variables=environment_variables,
|
environment_variables=environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
conversation_variables=conversation_variables,
|
||||||
|
@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
|
||||||
return {
|
return {
|
||||||
"result": "success",
|
"result": "success",
|
||||||
"hash": workflow.unique_hash,
|
"hash": workflow.unique_hash,
|
||||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
|
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
workflow = AppDslService.import_and_overwrite_workflow(
|
workflow = AppDslService.import_and_overwrite_workflow(
|
||||||
app_model=app_model,
|
app_model=app_model, data=args["data"], account=current_user
|
||||||
data=args['data'],
|
|
||||||
account=current_user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
@ -164,19 +162,15 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, location='json')
|
parser.add_argument("inputs", type=dict, location="json")
|
||||||
parser.add_argument('query', type=str, required=True, location='json', default='')
|
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||||
parser.add_argument('files', type=list, location='json')
|
parser.add_argument("files", type=list, location="json")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||||
user=current_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatDraftRunIterationNodeApi(Resource):
|
class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -204,16 +199,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, location='json')
|
parser.add_argument("inputs", type=dict, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_iteration(
|
response = AppGenerateService.generate_single_iteration(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||||
user=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
args=args,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -241,16 +233,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, location='json')
|
parser.add_argument("inputs", type=dict, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_single_iteration(
|
response = AppGenerateService.generate_single_iteration(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||||
user=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
args=args,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
class DraftWorkflowRunApi(Resource):
|
class DraftWorkflowRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -278,17 +267,13 @@ class DraftWorkflowRunApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||||
user=current_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -314,9 +299,7 @@ class WorkflowTaskStopApi(Resource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
|
|
||||||
return {
|
return {"result": "success"}
|
||||||
"result": "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DraftWorkflowNodeRunApi(Resource):
|
class DraftWorkflowNodeRunApi(Resource):
|
||||||
|
@ -334,22 +317,18 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||||
app_model=app_model,
|
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
|
||||||
node_id=node_id,
|
|
||||||
user_inputs=args.get('inputs'),
|
|
||||||
account=current_user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
|
||||||
class PublishedWorkflowApi(Resource):
|
class PublishedWorkflowApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -385,10 +364,7 @@ class PublishedWorkflowApi(Resource):
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||||
|
|
||||||
return {
|
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
|
||||||
"result": "success",
|
|
||||||
"created_at": TimestampField().format(workflow.created_at)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultBlockConfigsApi(Resource):
|
class DefaultBlockConfigsApi(Resource):
|
||||||
|
@ -423,22 +399,19 @@ class DefaultBlockConfigApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('q', type=str, location='args')
|
parser.add_argument("q", type=str, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
filters = None
|
filters = None
|
||||||
if args.get('q'):
|
if args.get("q"):
|
||||||
try:
|
try:
|
||||||
filters = json.loads(args.get('q'))
|
filters = json.loads(args.get("q"))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError('Invalid filters')
|
raise ValueError("Invalid filters")
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
return workflow_service.get_default_block_config(
|
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||||
node_type=block_type,
|
|
||||||
filters=filters
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertToWorkflowApi(Resource):
|
class ConvertToWorkflowApi(Resource):
|
||||||
|
@ -458,38 +431,40 @@ class ConvertToWorkflowApi(Resource):
|
||||||
|
|
||||||
if request.data:
|
if request.data:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
args = {}
|
args = {}
|
||||||
|
|
||||||
# convert to workflow mode
|
# convert to workflow mode
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
new_app_model = workflow_service.convert_to_workflow(
|
new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
args=args
|
|
||||||
)
|
|
||||||
|
|
||||||
# return app id
|
# return app id
|
||||||
return {
|
return {
|
||||||
'new_app_id': new_app_model.id,
|
"new_app_id": new_app_model.id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
|
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
|
||||||
api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
|
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
|
||||||
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
|
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
|
||||||
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
|
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
|
||||||
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
|
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||||
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
|
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||||
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
|
api.add_resource(
|
||||||
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
|
AdvancedChatDraftRunIterationNodeApi,
|
||||||
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
|
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||||
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
|
)
|
||||||
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
|
api.add_resource(
|
||||||
'/<string:block_type>')
|
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
|
||||||
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
|
)
|
||||||
|
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||||
|
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||||
|
api.add_resource(
|
||||||
|
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
|
||||||
|
)
|
||||||
|
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||||
|
|
|
@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource):
|
||||||
Get workflow app logs
|
Get workflow app logs
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('keyword', type=str, location='args')
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
|
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||||
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
|
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# get paginate workflow app logs
|
# get paginate workflow app logs
|
||||||
workflow_app_service = WorkflowAppService()
|
workflow_app_service = WorkflowAppService()
|
||||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||||
app_model=app_model,
|
app_model=app_model, args=args
|
||||||
args=args
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow_app_log_pagination
|
return workflow_app_log_pagination
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
|
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
|
||||||
|
|
|
@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||||
Get advanced chat app workflow run list
|
Get advanced chat app workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||||
app_model=app_model,
|
|
||||||
args=args
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource):
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_workflow_runs(
|
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||||
app_model=app_model,
|
|
||||||
args=args
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
|
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
|
||||||
|
|
||||||
return {
|
return {"data": node_executions}
|
||||||
'data': node_executions
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
|
api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||||
api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
|
api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
|
||||||
api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
|
api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||||
api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
|
api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||||
|
|
|
@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
||||||
FROM workflow_runs
|
FROM workflow_runs
|
||||||
WHERE app_id = :app_id
|
WHERE app_id = :app_id
|
||||||
AND triggered_from = :triggered_from
|
AND triggered_from = :triggered_from
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
arg_dict = {
|
||||||
|
"tz": account.timezone,
|
||||||
|
"app_id": app_model.id,
|
||||||
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||||
|
}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||||
'date': str(i.date),
|
|
||||||
'runs': i.runs
|
return jsonify({"data": response_data})
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
class WorkflowDailyTerminalsStatistic(Resource):
|
class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
||||||
FROM workflow_runs
|
FROM workflow_runs
|
||||||
WHERE app_id = :app_id
|
WHERE app_id = :app_id
|
||||||
AND triggered_from = :triggered_from
|
AND triggered_from = :triggered_from
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
arg_dict = {
|
||||||
|
"tz": account.timezone,
|
||||||
|
"app_id": app_model.id,
|
||||||
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||||
|
}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||||
'date': str(i.date),
|
|
||||||
'terminal_count': i.terminal_count
|
return jsonify({"data": response_data})
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
class WorkflowDailyTokenCostStatistic(Resource):
|
class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = '''
|
sql_query = """
|
||||||
SELECT
|
SELECT
|
||||||
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
SUM(workflow_runs.total_tokens) as token_count
|
SUM(workflow_runs.total_tokens) as token_count
|
||||||
FROM workflow_runs
|
FROM workflow_runs
|
||||||
WHERE app_id = :app_id
|
WHERE app_id = :app_id
|
||||||
AND triggered_from = :triggered_from
|
AND triggered_from = :triggered_from
|
||||||
'''
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
arg_dict = {
|
||||||
|
"tz": account.timezone,
|
||||||
|
"app_id": app_model.id,
|
||||||
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||||
|
}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at >= :start'
|
sql_query += " and created_at >= :start"
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += ' and created_at < :end'
|
sql_query += " and created_at < :end"
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += ' GROUP BY date order by date'
|
sql_query += " GROUP BY date order by date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append(
|
||||||
'date': str(i.date),
|
{
|
||||||
'token_count': i.token_count,
|
"date": str(i.date),
|
||||||
})
|
"token_count": i.token_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
|
@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
GROUP BY date, c.created_by) sub
|
GROUP BY date, c.created_by) sub
|
||||||
GROUP BY sub.date
|
GROUP BY sub.date
|
||||||
"""
|
"""
|
||||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
arg_dict = {
|
||||||
|
"tz": account.timezone,
|
||||||
|
"app_id": app_model.id,
|
||||||
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||||
|
}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
if args['start']:
|
if args["start"]:
|
||||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||||
start_datetime = start_datetime.replace(second=0)
|
start_datetime = start_datetime.replace(second=0)
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
|
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||||
arg_dict['start'] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
else:
|
else:
|
||||||
sql_query = sql_query.replace('{{start}}', '')
|
sql_query = sql_query.replace("{{start}}", "")
|
||||||
|
|
||||||
if args['end']:
|
if args["end"]:
|
||||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||||
end_datetime = end_datetime.replace(second=0)
|
end_datetime = end_datetime.replace(second=0)
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
|
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
|
||||||
arg_dict['end'] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
else:
|
else:
|
||||||
sql_query = sql_query.replace('{{end}}', '')
|
sql_query = sql_query.replace("{{end}}", "")
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||||
for i in rs:
|
for i in rs:
|
||||||
response_data.append({
|
response_data.append(
|
||||||
'date': str(i.date),
|
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"data": response_data})
|
||||||
'data': response_data
|
|
||||||
})
|
|
||||||
|
|
||||||
api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
|
|
||||||
api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
|
api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||||
api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
|
api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||||
api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
|
api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||||
|
api.add_resource(
|
||||||
|
WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
|
||||||
|
)
|
||||||
|
|
|
@ -8,24 +8,23 @@ from libs.login import current_user
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Optional[Callable] = None, *,
|
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||||
mode: Union[AppMode, list[AppMode]] = None):
|
|
||||||
def decorator(view_func):
|
def decorator(view_func):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args, **kwargs):
|
def decorated_view(*args, **kwargs):
|
||||||
if not kwargs.get('app_id'):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError('missing app_id in path parameters')
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
app_id = kwargs.get('app_id')
|
app_id = kwargs.get("app_id")
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|
||||||
del kwargs['app_id']
|
del kwargs["app_id"]
|
||||||
|
|
||||||
app_model = db.session.query(App).filter(
|
app_model = (
|
||||||
App.id == app_id,
|
db.session.query(App)
|
||||||
App.tenant_id == current_user.current_tenant_id,
|
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||||
App.status == 'normal'
|
.first()
|
||||||
).first()
|
)
|
||||||
|
|
||||||
if not app_model:
|
if not app_model:
|
||||||
raise AppNotFoundError()
|
raise AppNotFoundError()
|
||||||
|
@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *,
|
||||||
mode_values = {m.value for m in modes}
|
mode_values = {m.value for m in modes}
|
||||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||||
|
|
||||||
kwargs['app_model'] = app_model
|
kwargs["app_model"] = app_model
|
||||||
|
|
||||||
return view_func(*args, **kwargs)
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
return decorated_view
|
return decorated_view
|
||||||
|
|
||||||
if view is None:
|
if view is None:
|
||||||
|
|
|
@ -17,60 +17,61 @@ from services.account_service import RegisterService
|
||||||
class ActivateCheckApi(Resource):
|
class ActivateCheckApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
|
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
|
||||||
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
|
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
|
||||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
workspaceId = args['workspace_id']
|
workspaceId = args["workspace_id"]
|
||||||
reg_email = args['email']
|
reg_email = args["email"]
|
||||||
token = args['token']
|
token = args["token"]
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||||
|
|
||||||
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
|
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
|
||||||
|
|
||||||
|
|
||||||
class ActivateApi(Resource):
|
class ActivateApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
|
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
parser.add_argument(
|
||||||
location='json')
|
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
)
|
||||||
|
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||||
if invitation is None:
|
if invitation is None:
|
||||||
raise AlreadyActivateError()
|
raise AlreadyActivateError()
|
||||||
|
|
||||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||||
|
|
||||||
account = invitation['account']
|
account = invitation["account"]
|
||||||
account.name = args['name']
|
account.name = args["name"]
|
||||||
|
|
||||||
# generate password salt
|
# generate password salt
|
||||||
salt = secrets.token_bytes(16)
|
salt = secrets.token_bytes(16)
|
||||||
base64_salt = base64.b64encode(salt).decode()
|
base64_salt = base64.b64encode(salt).decode()
|
||||||
|
|
||||||
# encrypt password with salt
|
# encrypt password with salt
|
||||||
password_hashed = hash_password(args['password'], salt)
|
password_hashed = hash_password(args["password"], salt)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
account.password = base64_password_hashed
|
account.password = base64_password_hashed
|
||||||
account.password_salt = base64_salt
|
account.password_salt = base64_salt
|
||||||
account.interface_language = args['interface_language']
|
account.interface_language = args["interface_language"]
|
||||||
account.timezone = args['timezone']
|
account.timezone = args["timezone"]
|
||||||
account.interface_theme = 'light'
|
account.interface_theme = "light"
|
||||||
account.status = AccountStatus.ACTIVE.value
|
account.status = AccountStatus.ACTIVE.value
|
||||||
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
api.add_resource(ActivateCheckApi, "/activate/check")
|
||||||
api.add_resource(ActivateApi, '/activate')
|
api.add_resource(ActivateApi, "/activate")
|
||||||
|
|
|
@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource):
|
||||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||||
if data_source_api_key_bindings:
|
if data_source_api_key_bindings:
|
||||||
return {
|
return {
|
||||||
'sources': [{
|
"sources": [
|
||||||
'id': data_source_api_key_binding.id,
|
{
|
||||||
'category': data_source_api_key_binding.category,
|
"id": data_source_api_key_binding.id,
|
||||||
'provider': data_source_api_key_binding.provider,
|
"category": data_source_api_key_binding.category,
|
||||||
'disabled': data_source_api_key_binding.disabled,
|
"provider": data_source_api_key_binding.provider,
|
||||||
'created_at': int(data_source_api_key_binding.created_at.timestamp()),
|
"disabled": data_source_api_key_binding.disabled,
|
||||||
'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
|
"created_at": int(data_source_api_key_binding.created_at.timestamp()),
|
||||||
|
"updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
|
||||||
}
|
}
|
||||||
for data_source_api_key_binding in
|
for data_source_api_key_binding in data_source_api_key_bindings
|
||||||
data_source_api_key_bindings]
|
]
|
||||||
}
|
}
|
||||||
return {'sources': []}
|
return {"sources": []}
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthDataSourceBinding(Resource):
|
class ApiKeyAuthDataSourceBinding(Resource):
|
||||||
|
@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||||
|
@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||||
|
|
||||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
|
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
||||||
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
|
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
||||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
|
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
||||||
|
|
|
@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
|
||||||
|
|
||||||
def get_oauth_providers():
|
def get_oauth_providers():
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
|
notion_oauth = NotionOAuth(
|
||||||
|
client_id=dify_config.NOTION_CLIENT_ID,
|
||||||
client_secret=dify_config.NOTION_CLIENT_SECRET,
|
client_secret=dify_config.NOTION_CLIENT_SECRET,
|
||||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
|
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
|
||||||
|
)
|
||||||
|
|
||||||
OAUTH_PROVIDERS = {
|
OAUTH_PROVIDERS = {"notion": notion_oauth}
|
||||||
'notion': notion_oauth
|
|
||||||
}
|
|
||||||
return OAUTH_PROVIDERS
|
return OAUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,18 +37,16 @@ class OAuthDataSource(Resource):
|
||||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||||
print(vars(oauth_provider))
|
print(vars(oauth_provider))
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
|
if dify_config.NOTION_INTEGRATION_TYPE == "internal":
|
||||||
internal_secret = dify_config.NOTION_INTERNAL_SECRET
|
internal_secret = dify_config.NOTION_INTERNAL_SECRET
|
||||||
if not internal_secret:
|
if not internal_secret:
|
||||||
return {'error': 'Internal secret is not set'},
|
return ({"error": "Internal secret is not set"},)
|
||||||
oauth_provider.save_internal_access_token(internal_secret)
|
oauth_provider.save_internal_access_token(internal_secret)
|
||||||
return { 'data': '' }
|
return {"data": ""}
|
||||||
else:
|
else:
|
||||||
auth_url = oauth_provider.get_authorization_url()
|
auth_url = oauth_provider.get_authorization_url()
|
||||||
return { 'data': auth_url }, 200
|
return {"data": auth_url}, 200
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthDataSourceCallback(Resource):
|
class OAuthDataSourceCallback(Resource):
|
||||||
|
@ -57,17 +55,17 @@ class OAuthDataSourceCallback(Resource):
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
if 'code' in request.args:
|
if "code" in request.args:
|
||||||
code = request.args.get('code')
|
code = request.args.get("code")
|
||||||
|
|
||||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
|
||||||
elif 'error' in request.args:
|
elif "error" in request.args:
|
||||||
error = request.args.get('error')
|
error = request.args.get("error")
|
||||||
|
|
||||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
|
||||||
else:
|
else:
|
||||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
|
||||||
|
|
||||||
|
|
||||||
class OAuthDataSourceBinding(Resource):
|
class OAuthDataSourceBinding(Resource):
|
||||||
|
@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource):
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
if 'code' in request.args:
|
if "code" in request.args:
|
||||||
code = request.args.get('code')
|
code = request.args.get("code")
|
||||||
try:
|
try:
|
||||||
oauth_provider.get_access_token(code)
|
oauth_provider.get_access_token(code)
|
||||||
except requests.exceptions.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
logging.exception(
|
logging.exception(
|
||||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
|
||||||
return {'error': 'OAuth data source process failed'}, 400
|
)
|
||||||
|
return {"error": "OAuth data source process failed"}, 400
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class OAuthDataSourceSync(Resource):
|
class OAuthDataSourceSync(Resource):
|
||||||
|
@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource):
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
try:
|
try:
|
||||||
oauth_provider.sync_data_source(binding_id)
|
oauth_provider.sync_data_source(binding_id)
|
||||||
except requests.exceptions.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
logging.exception(
|
logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
return {"error": "OAuth data source process failed"}, 400
|
||||||
return {'error': 'OAuth data source process failed'}, 400
|
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
|
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
|
||||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
|
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
|
||||||
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
|
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
|
||||||
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
|
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||||
|
|
|
@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthFailedError(BaseHTTPException):
|
class ApiKeyAuthFailedError(BaseHTTPException):
|
||||||
error_code = 'auth_failed'
|
error_code = "auth_failed"
|
||||||
description = "{message}"
|
description = "{message}"
|
||||||
code = 500
|
code = 500
|
||||||
|
|
||||||
|
|
||||||
class InvalidEmailError(BaseHTTPException):
|
class InvalidEmailError(BaseHTTPException):
|
||||||
error_code = 'invalid_email'
|
error_code = "invalid_email"
|
||||||
description = "The email address is not valid."
|
description = "The email address is not valid."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class PasswordMismatchError(BaseHTTPException):
|
class PasswordMismatchError(BaseHTTPException):
|
||||||
error_code = 'password_mismatch'
|
error_code = "password_mismatch"
|
||||||
description = "The passwords do not match."
|
description = "The passwords do not match."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class InvalidTokenError(BaseHTTPException):
|
class InvalidTokenError(BaseHTTPException):
|
||||||
error_code = 'invalid_or_expired_token'
|
error_code = "invalid_or_expired_token"
|
||||||
description = "The token is invalid or has expired."
|
description = "The token is invalid or has expired."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||||
error_code = 'password_reset_rate_limit_exceeded'
|
error_code = "password_reset_rate_limit_exceeded"
|
||||||
description = "Password reset rate limit exceeded. Try again later."
|
description = "Password reset rate limit exceeded. Try again later."
|
||||||
code = 429
|
code = 429
|
||||||
|
|
||||||
|
|
|
@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordSendEmailApi(Resource):
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('email', type=str, required=True, location='json')
|
parser.add_argument("email", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
email = args['email']
|
email = args["email"]
|
||||||
|
|
||||||
if not email_validate(email):
|
if not email_validate(email):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordCheckApi(Resource):
|
class ForgotPasswordCheckApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
token = args['token']
|
token = args["token"]
|
||||||
|
|
||||||
reset_data = AccountService.get_reset_password_data(token)
|
reset_data = AccountService.get_reset_password_data(token)
|
||||||
|
|
||||||
if reset_data is None:
|
if reset_data is None:
|
||||||
return {'is_valid': False, 'email': None}
|
return {"is_valid": False, "email": None}
|
||||||
return {'is_valid': True, 'email': reset_data.get('email')}
|
return {"is_valid": True, "email": reset_data.get("email")}
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordResetApi(Resource):
|
class ForgotPasswordResetApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
|
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
|
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
new_password = args['new_password']
|
new_password = args["new_password"]
|
||||||
password_confirm = args['password_confirm']
|
password_confirm = args["password_confirm"]
|
||||||
|
|
||||||
if str(new_password).strip() != str(password_confirm).strip():
|
if str(new_password).strip() != str(password_confirm).strip():
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
token = args['token']
|
token = args["token"]
|
||||||
reset_data = AccountService.get_reset_password_data(token)
|
reset_data = AccountService.get_reset_password_data(token)
|
||||||
|
|
||||||
if reset_data is None:
|
if reset_data is None:
|
||||||
|
@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource):
|
||||||
password_hashed = hash_password(new_password, salt)
|
password_hashed = hash_password(new_password, salt)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
|
|
||||||
account = Account.query.filter_by(email=reset_data.get('email')).first()
|
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||||
account.password = base64_password_hashed
|
account.password = base64_password_hashed
|
||||||
account.password_salt = base64_salt
|
account.password_salt = base64_salt
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
|
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||||
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
|
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||||
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
|
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||||
|
|
|
@ -20,37 +20,39 @@ class LoginApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Authenticate user and login."""
|
"""Authenticate user and login."""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('email', type=email, required=True, location='json')
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument('password', type=valid_password, required=True, location='json')
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
|
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# todo: Verify the recaptcha
|
# todo: Verify the recaptcha
|
||||||
|
|
||||||
try:
|
try:
|
||||||
account = AccountService.authenticate(args['email'], args['password'])
|
account = AccountService.authenticate(args["email"], args["password"])
|
||||||
except services.errors.account.AccountLoginError as e:
|
except services.errors.account.AccountLoginError as e:
|
||||||
return {'code': 'unauthorized', 'message': str(e)}, 401
|
return {"code": "unauthorized", "message": str(e)}, 401
|
||||||
|
|
||||||
# SELF_HOSTED only have one workspace
|
# SELF_HOSTED only have one workspace
|
||||||
tenants = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
if len(tenants) == 0:
|
if len(tenants) == 0:
|
||||||
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
|
return {
|
||||||
|
"result": "fail",
|
||||||
|
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||||
|
}
|
||||||
|
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||||
|
|
||||||
return {'result': 'success', 'data': token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
class LogoutApi(Resource):
|
class LogoutApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def get(self):
|
||||||
account = cast(Account, flask_login.current_user)
|
account = cast(Account, flask_login.current_user)
|
||||||
token = request.headers.get('Authorization', '').split(' ')[1]
|
token = request.headers.get("Authorization", "").split(" ")[1]
|
||||||
AccountService.logout(account=account, token=token)
|
AccountService.logout(account=account, token=token)
|
||||||
flask_login.logout_user()
|
flask_login.logout_user()
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ResetPasswordApi(Resource):
|
class ResetPasswordApi(Resource):
|
||||||
|
@ -101,8 +103,8 @@ class ResetPasswordApi(Resource):
|
||||||
# # handle error
|
# # handle error
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(LoginApi, '/login')
|
api.add_resource(LoginApi, "/login")
|
||||||
api.add_resource(LogoutApi, '/logout')
|
api.add_resource(LogoutApi, "/logout")
|
||||||
|
|
|
@ -25,7 +25,7 @@ def get_oauth_providers():
|
||||||
github_oauth = GitHubOAuth(
|
github_oauth = GitHubOAuth(
|
||||||
client_id=dify_config.GITHUB_CLIENT_ID,
|
client_id=dify_config.GITHUB_CLIENT_ID,
|
||||||
client_secret=dify_config.GITHUB_CLIENT_SECRET,
|
client_secret=dify_config.GITHUB_CLIENT_SECRET,
|
||||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
|
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
|
||||||
)
|
)
|
||||||
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
|
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
|
||||||
google_oauth = None
|
google_oauth = None
|
||||||
|
@ -33,10 +33,10 @@ def get_oauth_providers():
|
||||||
google_oauth = GoogleOAuth(
|
google_oauth = GoogleOAuth(
|
||||||
client_id=dify_config.GOOGLE_CLIENT_ID,
|
client_id=dify_config.GOOGLE_CLIENT_ID,
|
||||||
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
|
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
|
||||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
|
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
|
||||||
)
|
)
|
||||||
|
|
||||||
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
|
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
|
||||||
return OAUTH_PROVIDERS
|
return OAUTH_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class OAuthLogin(Resource):
|
||||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||||
print(vars(oauth_provider))
|
print(vars(oauth_provider))
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
|
|
||||||
auth_url = oauth_provider.get_authorization_url()
|
auth_url = oauth_provider.get_authorization_url()
|
||||||
return redirect(auth_url)
|
return redirect(auth_url)
|
||||||
|
@ -59,20 +59,20 @@ class OAuthCallback(Resource):
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {'error': 'Invalid provider'}, 400
|
return {"error": "Invalid provider"}, 400
|
||||||
|
|
||||||
code = request.args.get('code')
|
code = request.args.get("code")
|
||||||
try:
|
try:
|
||||||
token = oauth_provider.get_access_token(code)
|
token = oauth_provider.get_access_token(code)
|
||||||
user_info = oauth_provider.get_user_info(token)
|
user_info = oauth_provider.get_user_info(token)
|
||||||
except requests.exceptions.HTTPError as e:
|
except requests.exceptions.HTTPError as e:
|
||||||
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
|
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
|
||||||
return {'error': 'OAuth process failed'}, 400
|
return {"error": "OAuth process failed"}, 400
|
||||||
|
|
||||||
account = _generate_account(provider, user_info)
|
account = _generate_account(provider, user_info)
|
||||||
# Check account status
|
# Check account status
|
||||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||||
return {'error': 'Account is banned or closed.'}, 403
|
return {"error": "Account is banned or closed."}, 403
|
||||||
|
|
||||||
if account.status == AccountStatus.PENDING.value:
|
if account.status == AccountStatus.PENDING.value:
|
||||||
account.status = AccountStatus.ACTIVE.value
|
account.status = AccountStatus.ACTIVE.value
|
||||||
|
@ -83,7 +83,7 @@ class OAuthCallback(Resource):
|
||||||
|
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||||
|
|
||||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
|
||||||
|
|
||||||
|
|
||||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||||
|
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
# Create account
|
# Create account
|
||||||
account_name = user_info.name if user_info.name else 'Dify'
|
account_name = user_info.name if user_info.name else "Dify"
|
||||||
account = RegisterService.register(
|
account = RegisterService.register(
|
||||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||||
)
|
)
|
||||||
|
@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(OAuthLogin, '/oauth/login/<provider>')
|
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
|
||||||
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
|
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
|
||||||
|
|
|
@ -9,28 +9,24 @@ from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Resource):
|
class Subscription(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
|
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
|
||||||
return BillingService.get_subscription(args['plan'],
|
return BillingService.get_subscription(
|
||||||
args['interval'],
|
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||||
current_user.email,
|
)
|
||||||
current_user.current_tenant_id)
|
|
||||||
|
|
||||||
|
|
||||||
class Invoices(Resource):
|
class Invoices(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -40,5 +36,5 @@ class Invoices(Resource):
|
||||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(Subscription, '/billing/subscription')
|
api.add_resource(Subscription, "/billing/subscription")
|
||||||
api.add_resource(Invoices, '/billing/invoices')
|
api.add_resource(Invoices, "/billing/invoices")
|
||||||
|
|
|
@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
|
||||||
class DataSourceApi(Resource):
|
class DataSourceApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# get workspace data source integrates
|
# get workspace data source integrates
|
||||||
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
|
data_source_integrates = (
|
||||||
|
db.session.query(DataSourceOauthBinding)
|
||||||
|
.filter(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||||
DataSourceOauthBinding.disabled == False
|
DataSourceOauthBinding.disabled == False,
|
||||||
).all()
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
base_url = request.url_root.rstrip('/')
|
base_url = request.url_root.rstrip("/")
|
||||||
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
||||||
providers = ["notion"]
|
providers = ["notion"]
|
||||||
|
|
||||||
|
@ -44,26 +47,30 @@ class DataSourceApi(Resource):
|
||||||
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
|
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
|
||||||
if existing_integrates:
|
if existing_integrates:
|
||||||
for existing_integrate in list(existing_integrates):
|
for existing_integrate in list(existing_integrates):
|
||||||
integrate_data.append({
|
integrate_data.append(
|
||||||
'id': existing_integrate.id,
|
{
|
||||||
'provider': provider,
|
"id": existing_integrate.id,
|
||||||
'created_at': existing_integrate.created_at,
|
"provider": provider,
|
||||||
'is_bound': True,
|
"created_at": existing_integrate.created_at,
|
||||||
'disabled': existing_integrate.disabled,
|
"is_bound": True,
|
||||||
'source_info': existing_integrate.source_info,
|
"disabled": existing_integrate.disabled,
|
||||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
"source_info": existing_integrate.source_info,
|
||||||
})
|
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
integrate_data.append({
|
integrate_data.append(
|
||||||
'id': None,
|
{
|
||||||
'provider': provider,
|
"id": None,
|
||||||
'created_at': None,
|
"provider": provider,
|
||||||
'source_info': None,
|
"created_at": None,
|
||||||
'is_bound': False,
|
"source_info": None,
|
||||||
'disabled': None,
|
"is_bound": False,
|
||||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
"disabled": None,
|
||||||
})
|
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||||
return {'data': integrate_data}, 200
|
}
|
||||||
|
)
|
||||||
|
return {"data": integrate_data}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -71,92 +78,82 @@ class DataSourceApi(Resource):
|
||||||
def patch(self, binding_id, action):
|
def patch(self, binding_id, action):
|
||||||
binding_id = str(binding_id)
|
binding_id = str(binding_id)
|
||||||
action = str(action)
|
action = str(action)
|
||||||
data_source_binding = DataSourceOauthBinding.query.filter_by(
|
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
||||||
id=binding_id
|
|
||||||
).first()
|
|
||||||
if data_source_binding is None:
|
if data_source_binding is None:
|
||||||
raise NotFound('Data source binding not found.')
|
raise NotFound("Data source binding not found.")
|
||||||
# enable binding
|
# enable binding
|
||||||
if action == 'enable':
|
if action == "enable":
|
||||||
if data_source_binding.disabled:
|
if data_source_binding.disabled:
|
||||||
data_source_binding.disabled = False
|
data_source_binding.disabled = False
|
||||||
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
db.session.add(data_source_binding)
|
db.session.add(data_source_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
else:
|
else:
|
||||||
raise ValueError('Data source is not disabled.')
|
raise ValueError("Data source is not disabled.")
|
||||||
# disable binding
|
# disable binding
|
||||||
if action == 'disable':
|
if action == "disable":
|
||||||
if not data_source_binding.disabled:
|
if not data_source_binding.disabled:
|
||||||
data_source_binding.disabled = True
|
data_source_binding.disabled = True
|
||||||
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
db.session.add(data_source_binding)
|
db.session.add(data_source_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
else:
|
else:
|
||||||
raise ValueError('Data source is disabled.')
|
raise ValueError("Data source is disabled.")
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionListApi(Resource):
|
class DataSourceNotionListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_notion_info_list_fields)
|
@marshal_with(integrate_notion_info_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
dataset_id = request.args.get('dataset_id', default=None, type=str)
|
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||||
exist_page_ids = []
|
exist_page_ids = []
|
||||||
# import notion in the exist dataset
|
# import notion in the exist dataset
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
if dataset.data_source_type != 'notion_import':
|
if dataset.data_source_type != "notion_import":
|
||||||
raise ValueError('Dataset is not notion type.')
|
raise ValueError("Dataset is not notion type.")
|
||||||
documents = Document.query.filter_by(
|
documents = Document.query.filter_by(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
data_source_type='notion_import',
|
data_source_type="notion_import",
|
||||||
enabled=True
|
enabled=True,
|
||||||
).all()
|
).all()
|
||||||
if documents:
|
if documents:
|
||||||
for document in documents:
|
for document in documents:
|
||||||
data_source_info = json.loads(document.data_source_info)
|
data_source_info = json.loads(document.data_source_info)
|
||||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||||
# get all authorized pages
|
# get all authorized pages
|
||||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||||
provider='notion',
|
|
||||||
disabled=False
|
|
||||||
).all()
|
).all()
|
||||||
if not data_source_bindings:
|
if not data_source_bindings:
|
||||||
return {
|
return {"notion_info": []}, 200
|
||||||
'notion_info': []
|
|
||||||
}, 200
|
|
||||||
pre_import_info_list = []
|
pre_import_info_list = []
|
||||||
for data_source_binding in data_source_bindings:
|
for data_source_binding in data_source_bindings:
|
||||||
source_info = data_source_binding.source_info
|
source_info = data_source_binding.source_info
|
||||||
pages = source_info['pages']
|
pages = source_info["pages"]
|
||||||
# Filter out already bound pages
|
# Filter out already bound pages
|
||||||
for page in pages:
|
for page in pages:
|
||||||
if page['page_id'] in exist_page_ids:
|
if page["page_id"] in exist_page_ids:
|
||||||
page['is_bound'] = True
|
page["is_bound"] = True
|
||||||
else:
|
else:
|
||||||
page['is_bound'] = False
|
page["is_bound"] = False
|
||||||
pre_import_info = {
|
pre_import_info = {
|
||||||
'workspace_name': source_info['workspace_name'],
|
"workspace_name": source_info["workspace_name"],
|
||||||
'workspace_icon': source_info['workspace_icon'],
|
"workspace_icon": source_info["workspace_icon"],
|
||||||
'workspace_id': source_info['workspace_id'],
|
"workspace_id": source_info["workspace_id"],
|
||||||
'pages': pages,
|
"pages": pages,
|
||||||
}
|
}
|
||||||
pre_import_info_list.append(pre_import_info)
|
pre_import_info_list.append(pre_import_info)
|
||||||
return {
|
return {"notion_info": pre_import_info_list}, 200
|
||||||
'notion_info': pre_import_info_list
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionApi(Resource):
|
class DataSourceNotionApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource):
|
||||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||||
db.and_(
|
db.and_(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||||
DataSourceOauthBinding.provider == 'notion',
|
DataSourceOauthBinding.provider == "notion",
|
||||||
DataSourceOauthBinding.disabled == False,
|
DataSourceOauthBinding.disabled == False,
|
||||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
if not data_source_binding:
|
if not data_source_binding:
|
||||||
raise NotFound('Data source binding not found.')
|
raise NotFound("Data source binding not found.")
|
||||||
|
|
||||||
extractor = NotionExtractor(
|
extractor = NotionExtractor(
|
||||||
notion_workspace_id=workspace_id,
|
notion_workspace_id=workspace_id,
|
||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=data_source_binding.access_token,
|
notion_access_token=data_source_binding.access_token,
|
||||||
tenant_id=current_user.current_tenant_id
|
tenant_id=current_user.current_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_docs = extractor.extract()
|
text_docs = extractor.extract()
|
||||||
return {
|
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
|
||||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
parser.add_argument(
|
||||||
|
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
notion_info_list = args['notion_info_list']
|
notion_info_list = args["notion_info_list"]
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
for notion_info in notion_info_list:
|
for notion_info in notion_info_list:
|
||||||
workspace_id = notion_info['workspace_id']
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info['pages']:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type="notion_import",
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page['page_id'],
|
"notion_obj_id": page["page_id"],
|
||||||
"notion_page_type": page['type'],
|
"notion_page_type": page["type"],
|
||||||
"tenant_id": current_user.current_tenant_id
|
"tenant_id": current_user.current_tenant_id,
|
||||||
},
|
},
|
||||||
document_model=args['doc_form']
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
response = indexing_runner.indexing_estimate(
|
||||||
args['process_rule'], args['doc_form'],
|
current_user.current_tenant_id,
|
||||||
args['doc_language'])
|
extract_settings,
|
||||||
|
args["process_rule"],
|
||||||
|
args["doc_form"],
|
||||||
|
args["doc_language"],
|
||||||
|
)
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionDatasetSyncApi(Resource):
|
class DataSourceNotionDatasetSyncApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionDocumentSyncApi(Resource):
|
class DataSourceNotionDocumentSyncApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||||
return 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
|
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||||
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
|
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
||||||
api.add_resource(DataSourceNotionApi,
|
api.add_resource(
|
||||||
'/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
|
DataSourceNotionApi,
|
||||||
'/datasets/notion-indexing-estimate')
|
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||||
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
|
"/datasets/notion-indexing-estimate",
|
||||||
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
|
)
|
||||||
|
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
||||||
|
api.add_resource(
|
||||||
|
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
||||||
|
)
|
||||||
|
|
|
@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
raise ValueError('Name must be between 1 to 40 characters.')
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
def _validate_description_length(description):
|
||||||
if len(description) > 400:
|
if len(description) > 400:
|
||||||
raise ValueError('Description cannot exceed 400 characters.')
|
raise ValueError("Description cannot exceed 400 characters.")
|
||||||
return description
|
return description
|
||||||
|
|
||||||
|
|
||||||
class DatasetListApi(Resource):
|
class DatasetListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
ids = request.args.getlist('ids')
|
ids = request.args.getlist("ids")
|
||||||
provider = request.args.get('provider', default="vendor")
|
provider = request.args.get("provider", default="vendor")
|
||||||
search = request.args.get('keyword', default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
tag_ids = request.args.getlist('tag_ids')
|
tag_ids = request.args.getlist("tag_ids")
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||||
else:
|
else:
|
||||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
datasets, total = DatasetService.get_datasets(
|
||||||
current_user.current_tenant_id, current_user, search, tag_ids)
|
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
|
||||||
|
)
|
||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(
|
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding_models = configurations.get_models(
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
only_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
model_names = []
|
model_names = []
|
||||||
for embedding_model in embedding_models:
|
for embedding_model in embedding_models:
|
||||||
|
@ -77,28 +72,22 @@ class DatasetListApi(Resource):
|
||||||
|
|
||||||
data = marshal(datasets, dataset_detail_fields)
|
data = marshal(datasets, dataset_detail_fields)
|
||||||
for item in data:
|
for item in data:
|
||||||
if item['indexing_technique'] == 'high_quality':
|
if item["indexing_technique"] == "high_quality":
|
||||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||||
if item_model in model_names:
|
if item_model in model_names:
|
||||||
item['embedding_available'] = True
|
item["embedding_available"] = True
|
||||||
else:
|
else:
|
||||||
item['embedding_available'] = False
|
item["embedding_available"] = False
|
||||||
else:
|
else:
|
||||||
item['embedding_available'] = True
|
item["embedding_available"] = True
|
||||||
|
|
||||||
if item.get('permission') == 'partial_members':
|
if item.get("permission") == "partial_members":
|
||||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
|
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
|
||||||
item.update({'partial_member_list': part_users_list})
|
item.update({"partial_member_list": part_users_list})
|
||||||
else:
|
else:
|
||||||
item.update({'partial_member_list': []})
|
item.update({"partial_member_list": []})
|
||||||
|
|
||||||
response = {
|
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||||
'data': data,
|
|
||||||
'has_more': len(datasets) == limit,
|
|
||||||
'limit': limit,
|
|
||||||
'total': total,
|
|
||||||
'page': page
|
|
||||||
}
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -106,13 +95,21 @@ class DatasetListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', nullable=False, required=True,
|
parser.add_argument(
|
||||||
help='type is required. Name must be between 1 to 40 characters.',
|
"name",
|
||||||
type=_validate_name)
|
nullable=False,
|
||||||
parser.add_argument('indexing_technique', type=str, location='json',
|
required=True,
|
||||||
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"indexing_technique",
|
||||||
|
type=str,
|
||||||
|
location="json",
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
help='Invalid indexing technique.')
|
help="Invalid indexing technique.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
@ -122,9 +119,9 @@ class DatasetListApi(Resource):
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=args['name'],
|
name=args["name"],
|
||||||
indexing_technique=args['indexing_technique'],
|
indexing_technique=args["indexing_technique"],
|
||||||
account=current_user
|
account=current_user,
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
|
@ -142,42 +139,36 @@ class DatasetApi(Resource):
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
try:
|
try:
|
||||||
DatasetService.check_dataset_permission(
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
dataset, current_user)
|
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
data = marshal(dataset, dataset_detail_fields)
|
data = marshal(dataset, dataset_detail_fields)
|
||||||
if data.get('permission') == 'partial_members':
|
if data.get("permission") == "partial_members":
|
||||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
data.update({'partial_member_list': part_users_list})
|
data.update({"partial_member_list": part_users_list})
|
||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(
|
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding_models = configurations.get_models(
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
only_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
model_names = []
|
model_names = []
|
||||||
for embedding_model in embedding_models:
|
for embedding_model in embedding_models:
|
||||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||||
|
|
||||||
if data['indexing_technique'] == 'high_quality':
|
if data["indexing_technique"] == "high_quality":
|
||||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||||
if item_model in model_names:
|
if item_model in model_names:
|
||||||
data['embedding_available'] = True
|
data["embedding_available"] = True
|
||||||
else:
|
else:
|
||||||
data['embedding_available'] = False
|
data["embedding_available"] = False
|
||||||
else:
|
else:
|
||||||
data['embedding_available'] = True
|
data["embedding_available"] = True
|
||||||
|
|
||||||
if data.get('permission') == 'partial_members':
|
if data.get("permission") == "partial_members":
|
||||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
data.update({'partial_member_list': part_users_list})
|
data.update({"partial_member_list": part_users_list})
|
||||||
|
|
||||||
return data, 200
|
return data, 200
|
||||||
|
|
||||||
|
@ -191,42 +182,49 @@ class DatasetApi(Resource):
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', nullable=False,
|
parser.add_argument(
|
||||||
help='type is required. Name must be between 1 to 40 characters.',
|
"name",
|
||||||
type=_validate_name)
|
nullable=False,
|
||||||
parser.add_argument('description',
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
location='json', store_missing=False,
|
type=_validate_name,
|
||||||
type=_validate_description_length)
|
)
|
||||||
parser.add_argument('indexing_technique', type=str, location='json',
|
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||||
|
parser.add_argument(
|
||||||
|
"indexing_technique",
|
||||||
|
type=str,
|
||||||
|
location="json",
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
help='Invalid indexing technique.')
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument('permission', type=str, location='json', choices=(
|
|
||||||
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
|
|
||||||
)
|
)
|
||||||
parser.add_argument('embedding_model', type=str,
|
parser.add_argument(
|
||||||
location='json', help='Invalid embedding model.')
|
"permission",
|
||||||
parser.add_argument('embedding_model_provider', type=str,
|
type=str,
|
||||||
location='json', help='Invalid embedding model provider.')
|
location="json",
|
||||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||||
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
|
help="Invalid permission.",
|
||||||
|
)
|
||||||
|
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||||
|
)
|
||||||
|
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||||
|
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if data.get('indexing_technique') == 'high_quality':
|
if data.get("indexing_technique") == "high_quality":
|
||||||
DatasetService.check_embedding_model_setting(dataset.tenant_id,
|
DatasetService.check_embedding_model_setting(
|
||||||
data.get('embedding_model_provider'),
|
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||||
data.get('embedding_model')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
DatasetPermissionService.check_permission(
|
DatasetPermissionService.check_permission(
|
||||||
current_user, dataset, data.get('permission'), data.get('partial_member_list')
|
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = DatasetService.update_dataset(
|
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||||
dataset_id_str, args, current_user)
|
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
@ -234,16 +232,19 @@ class DatasetApi(Resource):
|
||||||
result_data = marshal(dataset, dataset_detail_fields)
|
result_data = marshal(dataset, dataset_detail_fields)
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
|
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
tenant_id, dataset_id_str, data.get('partial_member_list')
|
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||||
)
|
)
|
||||||
# clear partial member list when permission is only_me or all_team_members
|
# clear partial member list when permission is only_me or all_team_members
|
||||||
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
|
elif (
|
||||||
|
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||||
|
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||||
|
):
|
||||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||||
|
|
||||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
result_data.update({'partial_member_list': partial_member_list})
|
result_data.update({"partial_member_list": partial_member_list})
|
||||||
|
|
||||||
return result_data, 200
|
return result_data, 200
|
||||||
|
|
||||||
|
@ -260,12 +261,13 @@ class DatasetApi(Resource):
|
||||||
try:
|
try:
|
||||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
else:
|
else:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
except services.errors.dataset.DatasetInUseError:
|
except services.errors.dataset.DatasetInUseError:
|
||||||
raise DatasetInUseError()
|
raise DatasetInUseError()
|
||||||
|
|
||||||
|
|
||||||
class DatasetUseCheckApi(Resource):
|
class DatasetUseCheckApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
||||||
return {'is_using': dataset_is_using}, 200
|
return {"is_using": dataset_is_using}, 200
|
||||||
|
|
||||||
|
|
||||||
class DatasetQueryApi(Resource):
|
class DatasetQueryApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -292,51 +294,53 @@ class DatasetQueryApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
|
|
||||||
dataset_queries, total = DatasetService.get_dataset_queries(
|
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||||
dataset_id=dataset.id,
|
|
||||||
page=page,
|
|
||||||
per_page=limit
|
|
||||||
)
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
'data': marshal(dataset_queries, dataset_query_detail_fields),
|
"data": marshal(dataset_queries, dataset_query_detail_fields),
|
||||||
'has_more': len(dataset_queries) == limit,
|
"has_more": len(dataset_queries) == limit,
|
||||||
'limit': limit,
|
"limit": limit,
|
||||||
'total': total,
|
"total": total,
|
||||||
'page': page
|
"page": page,
|
||||||
}
|
}
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
class DatasetIndexingEstimateApi(Resource):
|
class DatasetIndexingEstimateApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('indexing_technique', type=str, required=True,
|
parser.add_argument(
|
||||||
|
"indexing_technique",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
nullable=True, location='json')
|
nullable=True,
|
||||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
location="json",
|
||||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
)
|
||||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
location='json')
|
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||||
|
parser.add_argument(
|
||||||
|
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
if args['info_list']['data_source_type'] == 'upload_file':
|
if args["info_list"]["data_source_type"] == "upload_file":
|
||||||
file_ids = args['info_list']['file_info_list']['file_ids']
|
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||||
file_details = db.session.query(UploadFile).filter(
|
file_details = (
|
||||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
db.session.query(UploadFile)
|
||||||
UploadFile.id.in_(file_ids)
|
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
|
||||||
).all()
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
if file_details is None:
|
if file_details is None:
|
||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource):
|
||||||
if file_details:
|
if file_details:
|
||||||
for file_detail in file_details:
|
for file_detail in file_details:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file",
|
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||||
upload_file=file_detail,
|
|
||||||
document_model=args['doc_form']
|
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||||
notion_info_list = args['info_list']['notion_info_list']
|
notion_info_list = args["info_list"]["notion_info_list"]
|
||||||
for notion_info in notion_info_list:
|
for notion_info in notion_info_list:
|
||||||
workspace_id = notion_info['workspace_id']
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info['pages']:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type="notion_import",
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page['page_id'],
|
"notion_obj_id": page["page_id"],
|
||||||
"notion_page_type": page['type'],
|
"notion_page_type": page["type"],
|
||||||
"tenant_id": current_user.current_tenant_id
|
"tenant_id": current_user.current_tenant_id,
|
||||||
},
|
},
|
||||||
document_model=args['doc_form']
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif args['info_list']['data_source_type'] == 'website_crawl':
|
elif args["info_list"]["data_source_type"] == "website_crawl":
|
||||||
website_info_list = args['info_list']['website_info_list']
|
website_info_list = args["info_list"]["website_info_list"]
|
||||||
for url in website_info_list['urls']:
|
for url in website_info_list["urls"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="website_crawl",
|
datasource_type="website_crawl",
|
||||||
website_info={
|
website_info={
|
||||||
"provider": website_info_list['provider'],
|
"provider": website_info_list["provider"],
|
||||||
"job_id": website_info_list['job_id'],
|
"job_id": website_info_list["job_id"],
|
||||||
"url": url,
|
"url": url,
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_user.current_tenant_id,
|
||||||
"mode": 'crawl',
|
"mode": "crawl",
|
||||||
"only_main_content": website_info_list['only_main_content']
|
"only_main_content": website_info_list["only_main_content"],
|
||||||
},
|
},
|
||||||
document_model=args['doc_form']
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Data source type not support')
|
raise ValueError("Data source type not support")
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
try:
|
try:
|
||||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
response = indexing_runner.indexing_estimate(
|
||||||
args['process_rule'], args['doc_form'],
|
current_user.current_tenant_id,
|
||||||
args['doc_language'], args['dataset_id'],
|
extract_settings,
|
||||||
args['indexing_technique'])
|
args["process_rule"],
|
||||||
|
args["doc_form"],
|
||||||
|
args["doc_language"],
|
||||||
|
args["dataset_id"],
|
||||||
|
args["indexing_technique"],
|
||||||
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||||
"in the Settings -> Model Provider.")
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DatasetRelatedAppListApi(Resource):
|
class DatasetRelatedAppListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource):
|
||||||
if app_model:
|
if app_model:
|
||||||
related_apps.append(app_model)
|
related_apps.append(app_model)
|
||||||
|
|
||||||
return {
|
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||||
'data': related_apps,
|
|
||||||
'total': len(related_apps)
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetIndexingStatusApi(Resource):
|
class DatasetIndexingStatusApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
documents = db.session.query(Document).filter(
|
documents = (
|
||||||
Document.dataset_id == dataset_id,
|
db.session.query(Document)
|
||||||
Document.tenant_id == current_user.current_tenant_id
|
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
|
||||||
).all()
|
.all()
|
||||||
|
)
|
||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
completed_segments = DocumentSegment.query.filter(
|
||||||
|
DocumentSegment.completed_at.isnot(None),
|
||||||
DocumentSegment.document_id == str(document.id),
|
DocumentSegment.document_id == str(document.id),
|
||||||
DocumentSegment.status != 're_segment').count()
|
DocumentSegment.status != "re_segment",
|
||||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
).count()
|
||||||
DocumentSegment.status != 're_segment').count()
|
total_segments = DocumentSegment.query.filter(
|
||||||
|
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
|
||||||
|
).count()
|
||||||
document.completed_segments = completed_segments
|
document.completed_segments = completed_segments
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
documents_status.append(marshal(document, document_status_fields))
|
documents_status.append(marshal(document, document_status_fields))
|
||||||
data = {
|
data = {"data": documents_status}
|
||||||
'data': documents_status
|
|
||||||
}
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class DatasetApiKeyApi(Resource):
|
class DatasetApiKeyApi(Resource):
|
||||||
max_keys = 10
|
max_keys = 10
|
||||||
token_prefix = 'dataset-'
|
token_prefix = "dataset-"
|
||||||
resource_type = 'dataset'
|
resource_type = "dataset"
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_list)
|
@marshal_with(api_key_list)
|
||||||
def get(self):
|
def get(self):
|
||||||
keys = db.session.query(ApiToken). \
|
keys = (
|
||||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
db.session.query(ApiToken)
|
||||||
all()
|
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource):
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = db.session.query(ApiToken). \
|
current_key_count = (
|
||||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
db.session.query(ApiToken)
|
||||||
count()
|
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
if current_key_count >= self.max_keys:
|
if current_key_count >= self.max_keys:
|
||||||
flask_restful.abort(
|
flask_restful.abort(
|
||||||
400,
|
400,
|
||||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
code='max_keys_exceeded'
|
code="max_keys_exceeded",
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
|
@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DatasetApiDeleteApi(Resource):
|
class DatasetApiDeleteApi(Resource):
|
||||||
resource_type = 'dataset'
|
resource_type = "dataset"
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource):
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
key = db.session.query(ApiToken). \
|
key = (
|
||||||
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
|
db.session.query(ApiToken)
|
||||||
ApiToken.id == api_key_id). \
|
.filter(
|
||||||
first()
|
ApiToken.tenant_id == current_user.current_tenant_id,
|
||||||
|
ApiToken.type == self.resource_type,
|
||||||
|
ApiToken.id == api_key_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
flask_restful.abort(404, message='API key not found')
|
flask_restful.abort(404, message="API key not found")
|
||||||
|
|
||||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class DatasetApiBaseUrlApi(Resource):
|
class DatasetApiBaseUrlApi(Resource):
|
||||||
|
@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {
|
return {
|
||||||
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
|
"api_base_url": (
|
||||||
else request.host_url.rstrip('/')) + '/v1'
|
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
|
||||||
|
)
|
||||||
|
+ "/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
vector_type = dify_config.VECTOR_STORE
|
vector_type = dify_config.VECTOR_STORE
|
||||||
match vector_type:
|
match vector_type:
|
||||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
|
case (
|
||||||
|
VectorType.MILVUS
|
||||||
|
| VectorType.RELYT
|
||||||
|
| VectorType.PGVECTOR
|
||||||
|
| VectorType.TIDB_VECTOR
|
||||||
|
| VectorType.CHROMA
|
||||||
|
| VectorType.TENCENT
|
||||||
|
):
|
||||||
|
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
|
case (
|
||||||
|
VectorType.QDRANT
|
||||||
|
| VectorType.WEAVIATE
|
||||||
|
| VectorType.OPENSEARCH
|
||||||
|
| VectorType.ANALYTICDB
|
||||||
|
| VectorType.MYSCALE
|
||||||
|
| VectorType.ORACLE
|
||||||
|
| VectorType.ELASTICSEARCH
|
||||||
|
):
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
"retrieval_method": [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
|
||||||
]
|
|
||||||
}
|
|
||||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
|
|
||||||
return {
|
|
||||||
'retrieval_method': [
|
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||||
RetrievalMethod.HYBRID_SEARCH.value,
|
RetrievalMethod.HYBRID_SEARCH.value,
|
||||||
|
@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, vector_type):
|
def get(self, vector_type):
|
||||||
match vector_type:
|
match vector_type:
|
||||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
|
case (
|
||||||
|
VectorType.MILVUS
|
||||||
|
| VectorType.RELYT
|
||||||
|
| VectorType.TIDB_VECTOR
|
||||||
|
| VectorType.CHROMA
|
||||||
|
| VectorType.TENCENT
|
||||||
|
| VectorType.PGVECTO_RS
|
||||||
|
):
|
||||||
|
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
|
case (
|
||||||
|
VectorType.QDRANT
|
||||||
|
| VectorType.WEAVIATE
|
||||||
|
| VectorType.OPENSEARCH
|
||||||
|
| VectorType.ANALYTICDB
|
||||||
|
| VectorType.MYSCALE
|
||||||
|
| VectorType.ORACLE
|
||||||
|
| VectorType.ELASTICSEARCH
|
||||||
|
| VectorType.PGVECTOR
|
||||||
|
):
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
"retrieval_method": [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
|
||||||
]
|
|
||||||
}
|
|
||||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
|
|
||||||
return {
|
|
||||||
'retrieval_method': [
|
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||||
RetrievalMethod.HYBRID_SEARCH.value,
|
RetrievalMethod.HYBRID_SEARCH.value,
|
||||||
|
@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetErrorDocs(Resource):
|
class DatasetErrorDocs(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource):
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||||
|
|
||||||
return {
|
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||||
'data': [marshal(item, document_status_fields) for item in results],
|
|
||||||
'total': len(results)
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetPermissionUserListApi(Resource):
|
class DatasetPermissionUserListApi(Resource):
|
||||||
|
@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource):
|
||||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'data': partial_members_list,
|
"data": partial_members_list,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetListApi, '/datasets')
|
api.add_resource(DatasetListApi, "/datasets")
|
||||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||||
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
|
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
|
||||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
|
||||||
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
|
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
|
||||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
|
||||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
|
||||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
|
||||||
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
|
||||||
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
|
||||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
|
||||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
|
||||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
|
||||||
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
|
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
|
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=str, default=None, location='args')
|
parser.add_argument("last_id", type=str, default=None, location="args")
|
||||||
parser.add_argument('limit', type=int, default=20, location='args')
|
parser.add_argument("limit", type=int, default=20, location="args")
|
||||||
parser.add_argument('status', type=str,
|
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||||
action='append', default=[], location='args')
|
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||||
parser.add_argument('hit_count_gte', type=int,
|
parser.add_argument("enabled", type=str, default="all", location="args")
|
||||||
default=None, location='args')
|
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||||
parser.add_argument('enabled', type=str, default='all', location='args')
|
|
||||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
last_id = args['last_id']
|
last_id = args["last_id"]
|
||||||
limit = min(args['limit'], 100)
|
limit = min(args["limit"], 100)
|
||||||
status_list = args['status']
|
status_list = args["status"]
|
||||||
hit_count_gte = args['hit_count_gte']
|
hit_count_gte = args["hit_count_gte"]
|
||||||
keyword = args['keyword']
|
keyword = args["keyword"]
|
||||||
|
|
||||||
query = DocumentSegment.query.filter(
|
query = DocumentSegment.query.filter(
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if last_id is not None:
|
if last_id is not None:
|
||||||
last_segment = db.session.get(DocumentSegment, str(last_id))
|
last_segment = db.session.get(DocumentSegment, str(last_id))
|
||||||
if last_segment:
|
if last_segment:
|
||||||
query = query.filter(
|
query = query.filter(DocumentSegment.position > last_segment.position)
|
||||||
DocumentSegment.position > last_segment.position)
|
|
||||||
else:
|
else:
|
||||||
return {'data': [], 'has_more': False, 'limit': limit}, 200
|
return {"data": [], "has_more": False, "limit": limit}, 200
|
||||||
|
|
||||||
if status_list:
|
if status_list:
|
||||||
query = query.filter(DocumentSegment.status.in_(status_list))
|
query = query.filter(DocumentSegment.status.in_(status_list))
|
||||||
|
@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
|
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
|
||||||
|
|
||||||
if keyword:
|
if keyword:
|
||||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||||
|
|
||||||
if args['enabled'].lower() != 'all':
|
if args["enabled"].lower() != "all":
|
||||||
if args['enabled'].lower() == 'true':
|
if args["enabled"].lower() == "true":
|
||||||
query = query.filter(DocumentSegment.enabled == True)
|
query = query.filter(DocumentSegment.enabled == True)
|
||||||
elif args['enabled'].lower() == 'false':
|
elif args["enabled"].lower() == "false":
|
||||||
query = query.filter(DocumentSegment.enabled == False)
|
query = query.filter(DocumentSegment.enabled == False)
|
||||||
|
|
||||||
total = query.count()
|
total = query.count()
|
||||||
|
@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||||
segments = segments[:-1]
|
segments = segments[:-1]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'data': marshal(segments, segment_fields),
|
"data": marshal(segments, segment_fields),
|
||||||
'doc_form': document.doc_form,
|
"doc_form": document.doc_form,
|
||||||
'has_more': has_more,
|
"has_more": has_more,
|
||||||
'limit': limit,
|
"limit": limit,
|
||||||
'total': total
|
"total": total,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('vector_space')
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
def patch(self, dataset_id, segment_id, action):
|
def patch(self, dataset_id, segment_id, action):
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check user's model setting
|
# check user's model setting
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
|
@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
|
@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource):
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
|
||||||
segment = DocumentSegment.query.filter(
|
segment = DocumentSegment.query.filter(
|
||||||
DocumentSegment.id == str(segment_id),
|
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound('Segment not found.')
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
if segment.status != 'completed':
|
if segment.status != "completed":
|
||||||
raise NotFound('Segment is not completed, enable or disable function is not allowed')
|
raise NotFound("Segment is not completed, enable or disable function is not allowed")
|
||||||
|
|
||||||
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
|
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
|
||||||
cache_result = redis_client.get(document_indexing_cache_key)
|
cache_result = redis_client.get(document_indexing_cache_key)
|
||||||
if cache_result is not None:
|
if cache_result is not None:
|
||||||
raise InvalidActionError("Document is being indexed, please try again later")
|
raise InvalidActionError("Document is being indexed, please try again later")
|
||||||
|
|
||||||
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
|
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
if cache_result is not None:
|
if cache_result is not None:
|
||||||
raise InvalidActionError("Segment is being indexed, please try again later")
|
raise InvalidActionError("Segment is being indexed, please try again later")
|
||||||
|
@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||||
|
|
||||||
enable_segment_to_index_task.delay(segment.id)
|
enable_segment_to_index_task.delay(segment.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
elif action == "disable":
|
elif action == "disable":
|
||||||
if not segment.enabled:
|
if not segment.enabled:
|
||||||
raise InvalidActionError("Segment is already disabled.")
|
raise InvalidActionError("Segment is already disabled.")
|
||||||
|
@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||||
|
|
||||||
disable_segment_from_index_task.delay(segment.id)
|
disable_segment_from_index_task.delay(segment.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
else:
|
else:
|
||||||
raise InvalidActionError()
|
raise InvalidActionError()
|
||||||
|
|
||||||
|
@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('vector_space')
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_knowledge_limit_check('add_segment')
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
try:
|
try:
|
||||||
|
@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
segment = SegmentService.create_segment(args, document, dataset)
|
segment = SegmentService.create_segment(args, document, dataset)
|
||||||
return {
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
'data': marshal(segment, segment_fields),
|
|
||||||
'doc_form': document.doc_form
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('vector_space')
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check user's model setting
|
# check user's model setting
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
|
@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
# check segment
|
# check segment
|
||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = DocumentSegment.query.filter(
|
segment = DocumentSegment.query.filter(
|
||||||
DocumentSegment.id == str(segment_id),
|
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
).first()
|
).first()
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound('Segment not found.')
|
raise NotFound("Segment not found.")
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||||
return {
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
'data': marshal(segment, segment_fields),
|
|
||||||
'doc_form': document.doc_form
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check user's model setting
|
# check user's model setting
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
# check segment
|
# check segment
|
||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = DocumentSegment.query.filter(
|
segment = DocumentSegment.query.filter(
|
||||||
DocumentSegment.id == str(segment_id),
|
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
).first()
|
).first()
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound('Segment not found.')
|
raise NotFound("Segment not found.")
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
SegmentService.delete_segment(segment, document, dataset)
|
SegmentService.delete_segment(segment, document, dataset)
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('vector_space')
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_knowledge_limit_check('add_segment')
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
# check file type
|
# check file type
|
||||||
if not file.filename.endswith('.csv'):
|
if not file.filename.endswith(".csv"):
|
||||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
df = pd.read_csv(file)
|
df = pd.read_csv(file)
|
||||||
result = []
|
result = []
|
||||||
for index, row in df.iterrows():
|
for index, row in df.iterrows():
|
||||||
if document.doc_form == 'qa_model':
|
if document.doc_form == "qa_model":
|
||||||
data = {'content': row[0], 'answer': row[1]}
|
data = {"content": row[0], "answer": row[1]}
|
||||||
else:
|
else:
|
||||||
data = {'content': row[0]}
|
data = {"content": row[0]}
|
||||||
result.append(data)
|
result.append(data)
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
raise ValueError("The CSV file is empty.")
|
raise ValueError("The CSV file is empty.")
|
||||||
# async job
|
# async job
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
|
indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
|
||||||
# send batch add segments task
|
# send batch add segments task
|
||||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
redis_client.setnx(indexing_cache_key, "waiting")
|
||||||
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
|
batch_create_segment_to_index_task.delay(
|
||||||
current_user.current_tenant_id, current_user.id)
|
str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {'error': str(e)}, 500
|
return {"error": str(e)}, 500
|
||||||
return {
|
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||||
'job_id': job_id,
|
|
||||||
'job_status': 'waiting'
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id):
|
def get(self, job_id):
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
|
indexing_cache_key = "segment_batch_import_{}".format(job_id)
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
if cache_result is None:
|
if cache_result is None:
|
||||||
raise ValueError("The job is not exist.")
|
raise ValueError("The job is not exist.")
|
||||||
|
|
||||||
return {
|
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||||
'job_id': job_id,
|
|
||||||
'job_status': cache_result.decode()
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetDocumentSegmentListApi,
|
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
|
||||||
api.add_resource(DatasetDocumentSegmentApi,
|
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||||
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
|
api.add_resource(
|
||||||
api.add_resource(DatasetDocumentSegmentAddApi,
|
DatasetDocumentSegmentUpdateApi,
|
||||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
)
|
||||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
api.add_resource(
|
||||||
api.add_resource(DatasetDocumentSegmentBatchImportApi,
|
DatasetDocumentSegmentBatchImportApi,
|
||||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
|
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||||
'/datasets/batch_import_status/<uuid:job_id>')
|
"/datasets/batch_import_status/<uuid:job_id>",
|
||||||
|
)
|
||||||
|
|
|
@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class NoFileUploadedError(BaseHTTPException):
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_file_uploaded'
|
error_code = "no_file_uploaded"
|
||||||
description = "Please upload your file."
|
description = "Please upload your file."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = 'too_many_files'
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class FileTooLargeError(BaseHTTPException):
|
class FileTooLargeError(BaseHTTPException):
|
||||||
error_code = 'file_too_large'
|
error_code = "file_too_large"
|
||||||
description = "File size exceeded. {message}"
|
description = "File size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||||
error_code = 'high_quality_dataset_only'
|
error_code = "high_quality_dataset_only"
|
||||||
description = "Current operation only supports 'high-quality' datasets."
|
description = "Current operation only supports 'high-quality' datasets."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DatasetNotInitializedError(BaseHTTPException):
|
class DatasetNotInitializedError(BaseHTTPException):
|
||||||
error_code = 'dataset_not_initialized'
|
error_code = "dataset_not_initialized"
|
||||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||||
error_code = 'archived_document_immutable'
|
error_code = "archived_document_immutable"
|
||||||
description = "The archived document is not editable."
|
description = "The archived document is not editable."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class DatasetNameDuplicateError(BaseHTTPException):
|
class DatasetNameDuplicateError(BaseHTTPException):
|
||||||
error_code = 'dataset_name_duplicate'
|
error_code = "dataset_name_duplicate"
|
||||||
description = "The dataset name already exists. Please modify your dataset name."
|
description = "The dataset name already exists. Please modify your dataset name."
|
||||||
code = 409
|
code = 409
|
||||||
|
|
||||||
|
|
||||||
class InvalidActionError(BaseHTTPException):
|
class InvalidActionError(BaseHTTPException):
|
||||||
error_code = 'invalid_action'
|
error_code = "invalid_action"
|
||||||
description = "Invalid action."
|
description = "Invalid action."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||||
error_code = 'document_already_finished'
|
error_code = "document_already_finished"
|
||||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingError(BaseHTTPException):
|
class DocumentIndexingError(BaseHTTPException):
|
||||||
error_code = 'document_indexing'
|
error_code = "document_indexing"
|
||||||
description = "The document is being processed and cannot be edited."
|
description = "The document is being processed and cannot be edited."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class InvalidMetadataError(BaseHTTPException):
|
class InvalidMetadataError(BaseHTTPException):
|
||||||
error_code = 'invalid_metadata'
|
error_code = "invalid_metadata"
|
||||||
description = "The metadata content is incorrect. Please check and verify."
|
description = "The metadata content is incorrect. Please check and verify."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class WebsiteCrawlError(BaseHTTPException):
|
class WebsiteCrawlError(BaseHTTPException):
|
||||||
error_code = 'crawl_failed'
|
error_code = "crawl_failed"
|
||||||
description = "{message}"
|
description = "{message}"
|
||||||
code = 500
|
code = 500
|
||||||
|
|
||||||
|
|
||||||
class DatasetInUseError(BaseHTTPException):
|
class DatasetInUseError(BaseHTTPException):
|
||||||
error_code = 'dataset_in_use'
|
error_code = "dataset_in_use"
|
||||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||||
code = 409
|
code = 409
|
||||||
|
|
||||||
|
|
||||||
class IndexingEstimateError(BaseHTTPException):
|
class IndexingEstimateError(BaseHTTPException):
|
||||||
error_code = 'indexing_estimate_error'
|
error_code = "indexing_estimate_error"
|
||||||
description = "Knowledge indexing estimate failed: {message}"
|
description = "Knowledge indexing estimate failed: {message}"
|
||||||
code = 500
|
code = 500
|
||||||
|
|
|
@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
|
||||||
class FileApi(Resource):
|
class FileApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -31,23 +30,22 @@ class FileApi(Resource):
|
||||||
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||||
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||||
return {
|
return {
|
||||||
'file_size_limit': file_size_limit,
|
"file_size_limit": file_size_limit,
|
||||||
'batch_count_limit': batch_count_limit,
|
"batch_count_limit": batch_count_limit,
|
||||||
'image_file_size_limit': image_file_size_limit
|
"image_file_size_limit": image_file_size_limit,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(file_fields)
|
@marshal_with(file_fields)
|
||||||
@cloud_edition_billing_resource_check(resource='documents')
|
@cloud_edition_billing_resource_check(resource="documents")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
|
@ -69,7 +67,7 @@ class FilePreviewApi(Resource):
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
text = FileService.get_file_preview(file_id)
|
text = FileService.get_file_preview(file_id)
|
||||||
return {'content': text}
|
return {"content": text}
|
||||||
|
|
||||||
|
|
||||||
class FileSupportTypeApi(Resource):
|
class FileSupportTypeApi(Resource):
|
||||||
|
@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
etl_type = dify_config.ETL_TYPE
|
etl_type = dify_config.ETL_TYPE
|
||||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
||||||
return {'allowed_extensions': allowed_extensions}
|
return {"allowed_extensions": allowed_extensions}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, '/files/upload')
|
api.add_resource(FileApi, "/files/upload")
|
||||||
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
|
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||||
api.add_resource(FileSupportTypeApi, '/files/support-type')
|
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||||
|
|
|
@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
|
|
||||||
class HitTestingApi(Resource):
|
class HitTestingApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -46,8 +45,8 @@ class HitTestingApi(Resource):
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('query', type=str, location='json')
|
parser.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
|
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
@ -55,13 +54,13 @@ class HitTestingApi(Resource):
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.retrieve(
|
response = HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args['query'],
|
query=args["query"],
|
||||||
account=current_user,
|
account=current_user,
|
||||||
retrieval_model=args['retrieval_model'],
|
retrieval_model=args["retrieval_model"],
|
||||||
limit=10
|
limit=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
except services.errors.index.IndexNotInitializedError:
|
except services.errors.index.IndexNotInitializedError:
|
||||||
raise DatasetNotInitializedError()
|
raise DatasetNotInitializedError()
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
|
@ -73,7 +72,8 @@ class HitTestingApi(Resource):
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -83,4 +83,4 @@ class HitTestingApi(Resource):
|
||||||
raise InternalServerError(str(e))
|
raise InternalServerError(str(e))
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
|
|
|
@ -9,16 +9,14 @@ from services.website_service import WebsiteService
|
||||||
|
|
||||||
|
|
||||||
class WebsiteCrawlApi(Resource):
|
class WebsiteCrawlApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('provider', type=str, choices=['firecrawl'],
|
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
|
||||||
required=True, nullable=True, location='json')
|
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
|
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
WebsiteService.document_create_args_validate(args)
|
WebsiteService.document_create_args_validate(args)
|
||||||
# crawl url
|
# crawl url
|
||||||
|
@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id: str):
|
def get(self, job_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
|
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# get crawl status
|
# get crawl status
|
||||||
try:
|
try:
|
||||||
result = WebsiteService.get_crawl_status(job_id, args['provider'])
|
result = WebsiteService.get_crawl_status(job_id, args["provider"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise WebsiteCrawlError(str(e))
|
raise WebsiteCrawlError(str(e))
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(WebsiteCrawlApi, '/website/crawl')
|
api.add_resource(WebsiteCrawlApi, "/website/crawl")
|
||||||
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
|
api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")
|
||||||
|
|
|
@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class AlreadySetupError(BaseHTTPException):
|
class AlreadySetupError(BaseHTTPException):
|
||||||
error_code = 'already_setup'
|
error_code = "already_setup"
|
||||||
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
|
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class NotSetupError(BaseHTTPException):
|
class NotSetupError(BaseHTTPException):
|
||||||
error_code = 'not_setup'
|
error_code = "not_setup"
|
||||||
description = "Dify has not been initialized and installed yet. " \
|
description = (
|
||||||
|
"Dify has not been initialized and installed yet. "
|
||||||
"Please proceed with the initialization and installation process first."
|
"Please proceed with the initialization and installation process first."
|
||||||
|
)
|
||||||
code = 401
|
code = 401
|
||||||
|
|
||||||
|
|
||||||
class NotInitValidateError(BaseHTTPException):
|
class NotInitValidateError(BaseHTTPException):
|
||||||
error_code = 'not_init_validated'
|
error_code = "not_init_validated"
|
||||||
description = "Init validation has not been completed yet. " \
|
description = (
|
||||||
"Please proceed with the init validation process first."
|
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
|
||||||
|
)
|
||||||
code = 401
|
code = 401
|
||||||
|
|
||||||
|
|
||||||
class InitValidateFailedError(BaseHTTPException):
|
class InitValidateFailedError(BaseHTTPException):
|
||||||
error_code = 'init_validate_failed'
|
error_code = "init_validate_failed"
|
||||||
description = "Init validation failed. Please check the password and try again."
|
description = "Init validation failed. Please check the password and try again."
|
||||||
code = 401
|
code = 401
|
||||||
|
|
||||||
|
|
||||||
class AccountNotLinkTenantError(BaseHTTPException):
|
class AccountNotLinkTenantError(BaseHTTPException):
|
||||||
error_code = 'account_not_link_tenant'
|
error_code = "account_not_link_tenant"
|
||||||
description = "Account not link tenant."
|
description = "Account not link tenant."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class AlreadyActivateError(BaseHTTPException):
|
class AlreadyActivateError(BaseHTTPException):
|
||||||
error_code = 'already_activate'
|
error_code = "already_activate"
|
||||||
description = "Auth Token is invalid or account already activated, please check again."
|
description = "Auth Token is invalid or account already activated, please check again."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
|
@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AudioService.transcript_asr(
|
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||||
app_model=app_model,
|
|
||||||
file=file,
|
|
||||||
end_user=None
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument('text', type=str, location='json')
|
parser.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument("streaming", type=bool, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id', None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get('text', None)
|
text = args.get("text", None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (
|
||||||
|
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict
|
||||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
):
|
||||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
|
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
voice = (
|
||||||
|
args.get("voice")
|
||||||
|
if args.get("voice")
|
||||||
|
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||||
app_model=app_model,
|
|
||||||
message_id=message_id,
|
|
||||||
voice=voice,
|
|
||||||
text=text
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
logging.exception("App model config broken.")
|
logging.exception("App model config broken.")
|
||||||
|
@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource):
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||||
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
|
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||||
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
|
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
|
||||||
# endpoint='installed_app_text_with_message_id')
|
# endpoint='installed_app_text_with_message_id')
|
||||||
|
|
|
@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
# define completion api for user
|
# define completion api for user
|
||||||
class CompletionApi(InstalledAppResource):
|
class CompletionApi(InstalledAppResource):
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, location='json', default='')
|
parser.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||||
user=current_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource):
|
||||||
class CompletionStopApi(InstalledAppResource):
|
class CompletionStopApi(InstalledAppResource):
|
||||||
def post(self, installed_app, task_id):
|
def post(self, installed_app, task_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ChatApi(InstalledAppResource):
|
class ChatApi(InstalledAppResource):
|
||||||
|
@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, required=True, location='json')
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
user=current_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
|
api.add_resource(
|
||||||
api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
|
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
||||||
api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
|
)
|
||||||
api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
|
api.add_resource(
|
||||||
|
CompletionStopApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||||
|
endpoint="installed_app_stop_completion",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ChatStopApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||||
|
endpoint="installed_app_stop_chat_completion",
|
||||||
|
)
|
||||||
|
|
|
@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
|
|
||||||
class ConversationListApi(InstalledAppResource):
|
class ConversationListApi(InstalledAppResource):
|
||||||
|
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pinned = None
|
pinned = None
|
||||||
if 'pinned' in args and args['pinned'] is not None:
|
if "pinned" in args and args["pinned"] is not None:
|
||||||
pinned = True if args['pinned'] == 'true' else False
|
pinned = True if args["pinned"] == "true" else False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return WebConversationService.pagination_by_last_id(
|
return WebConversationService.pagination_by_last_id(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
last_id=args['last_id'],
|
last_id=args["last_id"],
|
||||||
limit=args['limit'],
|
limit=args["limit"],
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
pinned=pinned,
|
pinned=pinned,
|
||||||
)
|
)
|
||||||
|
@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource):
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenameApi(InstalledAppResource):
|
class ConversationRenameApi(InstalledAppResource):
|
||||||
|
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, installed_app, c_id):
|
def post(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource):
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=False, location='json')
|
parser.add_argument("name", type=str, required=False, location="json")
|
||||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(
|
||||||
app_model,
|
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||||
conversation_id,
|
|
||||||
current_user,
|
|
||||||
args['name'],
|
|
||||||
args['auto_generate']
|
|
||||||
)
|
)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
class ConversationPinApi(InstalledAppResource):
|
class ConversationPinApi(InstalledAppResource):
|
||||||
|
|
||||||
def patch(self, installed_app, c_id):
|
def patch(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
|
api.add_resource(
|
||||||
api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
|
ConversationRenameApi,
|
||||||
api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||||
api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
|
endpoint="installed_app_conversation_rename",
|
||||||
api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ConversationApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||||
|
endpoint="installed_app_conversation",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ConversationPinApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||||
|
endpoint="installed_app_conversation_pin",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ConversationUnPinApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||||
|
endpoint="installed_app_conversation_unpin",
|
||||||
|
)
|
||||||
|
|
|
@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class NotCompletionAppError(BaseHTTPException):
|
class NotCompletionAppError(BaseHTTPException):
|
||||||
error_code = 'not_completion_app'
|
error_code = "not_completion_app"
|
||||||
description = "Not Completion App"
|
description = "Not Completion App"
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotChatAppError(BaseHTTPException):
|
class NotChatAppError(BaseHTTPException):
|
||||||
error_code = 'not_chat_app'
|
error_code = "not_chat_app"
|
||||||
description = "App mode is invalid."
|
description = "App mode is invalid."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotWorkflowAppError(BaseHTTPException):
|
class NotWorkflowAppError(BaseHTTPException):
|
||||||
error_code = 'not_workflow_app'
|
error_code = "not_workflow_app"
|
||||||
description = "Only support workflow app."
|
description = "Only support workflow app."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||||
error_code = 'app_suggested_questions_after_answer_disabled'
|
error_code = "app_suggested_questions_after_answer_disabled"
|
||||||
description = "Function Suggested questions after answer disabled."
|
description = "Function Suggested questions after answer disabled."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
|
@ -21,72 +21,71 @@ class InstalledAppsListApi(Resource):
|
||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
installed_apps = db.session.query(InstalledApp).filter(
|
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||||
InstalledApp.tenant_id == current_tenant_id
|
|
||||||
).all()
|
|
||||||
|
|
||||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||||
installed_apps = [
|
installed_apps = [
|
||||||
{
|
{
|
||||||
'id': installed_app.id,
|
"id": installed_app.id,
|
||||||
'app': installed_app.app,
|
"app": installed_app.app,
|
||||||
'app_owner_tenant_id': installed_app.app_owner_tenant_id,
|
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||||
'is_pinned': installed_app.is_pinned,
|
"is_pinned": installed_app.is_pinned,
|
||||||
'last_used_at': installed_app.last_used_at,
|
"last_used_at": installed_app.last_used_at,
|
||||||
'editable': current_user.role in ["owner", "admin"],
|
"editable": current_user.role in ["owner", "admin"],
|
||||||
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
|
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||||
}
|
}
|
||||||
for installed_app in installed_apps
|
for installed_app in installed_apps
|
||||||
]
|
]
|
||||||
installed_apps.sort(key=lambda app: (-app['is_pinned'],
|
installed_apps.sort(
|
||||||
app['last_used_at'] is None,
|
key=lambda app: (
|
||||||
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
|
-app["is_pinned"],
|
||||||
|
app["last_used_at"] is None,
|
||||||
|
-app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {'installed_apps': installed_apps}
|
return {"installed_apps": installed_apps}
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('apps')
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
|
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
|
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound('App not found')
|
raise NotFound("App not found")
|
||||||
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
current_tenant_id = current_user.current_tenant_id
|
||||||
app = db.session.query(App).filter(
|
app = db.session.query(App).filter(App.id == args["app_id"]).first()
|
||||||
App.id == args['app_id']
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
raise NotFound('App not found')
|
raise NotFound("App not found")
|
||||||
|
|
||||||
if not app.is_public:
|
if not app.is_public:
|
||||||
raise Forbidden('You can\'t install a non-public app')
|
raise Forbidden("You can't install a non-public app")
|
||||||
|
|
||||||
installed_app = InstalledApp.query.filter(and_(
|
installed_app = InstalledApp.query.filter(
|
||||||
InstalledApp.app_id == args['app_id'],
|
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
|
||||||
InstalledApp.tenant_id == current_tenant_id
|
).first()
|
||||||
)).first()
|
|
||||||
|
|
||||||
if installed_app is None:
|
if installed_app is None:
|
||||||
# todo: position
|
# todo: position
|
||||||
recommended_app.install_count += 1
|
recommended_app.install_count += 1
|
||||||
|
|
||||||
new_installed_app = InstalledApp(
|
new_installed_app = InstalledApp(
|
||||||
app_id=args['app_id'],
|
app_id=args["app_id"],
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_owner_tenant_id=app.tenant_id,
|
app_owner_tenant_id=app.tenant_id,
|
||||||
is_pinned=False,
|
is_pinned=False,
|
||||||
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||||
)
|
)
|
||||||
db.session.add(new_installed_app)
|
db.session.add(new_installed_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'message': 'App installed successfully'}
|
return {"message": "App installed successfully"}
|
||||||
|
|
||||||
|
|
||||||
class InstalledAppApi(InstalledAppResource):
|
class InstalledAppApi(InstalledAppResource):
|
||||||
|
@ -94,30 +93,31 @@ class InstalledAppApi(InstalledAppResource):
|
||||||
update and delete an installed app
|
update and delete an installed app
|
||||||
use InstalledAppResource to apply default decorators and get installed_app
|
use InstalledAppResource to apply default decorators and get installed_app
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||||
raise BadRequest('You can\'t uninstall an app owned by the current tenant')
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
db.session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success', 'message': 'App uninstalled successfully'}
|
return {"result": "success", "message": "App uninstalled successfully"}
|
||||||
|
|
||||||
def patch(self, installed_app):
|
def patch(self, installed_app):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('is_pinned', type=inputs.boolean)
|
parser.add_argument("is_pinned", type=inputs.boolean)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
commit_args = False
|
commit_args = False
|
||||||
if 'is_pinned' in args:
|
if "is_pinned" in args:
|
||||||
installed_app.is_pinned = args['is_pinned']
|
installed_app.is_pinned = args["is_pinned"]
|
||||||
commit_args = True
|
commit_args = True
|
||||||
|
|
||||||
if commit_args:
|
if commit_args:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success', 'message': 'App info updated successfully'}
|
return {"result": "success", "message": "App info updated successfully"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(InstalledAppsListApi, '/installed-apps')
|
api.add_resource(InstalledAppsListApi, "/installed-apps")
|
||||||
api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
|
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
|
||||||
|
|
|
@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return MessageService.pagination_by_first_id(app_model, current_user,
|
return MessageService.pagination_by_first_id(
|
||||||
args['conversation_id'], args['first_id'], args['limit'])
|
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
|
)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
except services.errors.message.FirstMessageNotExistsError:
|
except services.errors.message.FirstMessageNotExistsError:
|
||||||
raise NotFound("First Message Not Exists.")
|
raise NotFound("First Message Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackApi(InstalledAppResource):
|
class MessageFeedbackApi(InstalledAppResource):
|
||||||
def post(self, installed_app, message_id):
|
def post(self, installed_app, message_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
|
||||||
except services.errors.message.MessageNotExistsError:
|
except services.errors.message.MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
parser.add_argument(
|
||||||
|
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
|
@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
user=current_user,
|
user=current_user,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
streaming=streaming
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
|
@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
user=current_user,
|
|
||||||
message_id=message_id,
|
|
||||||
invoke_from=InvokeFrom.EXPLORE
|
|
||||||
)
|
)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message not found")
|
raise NotFound("Message not found")
|
||||||
|
@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
return {'data': questions}
|
return {"data": questions}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
||||||
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
api.add_resource(
|
||||||
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
|
MessageFeedbackApi,
|
||||||
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||||
|
endpoint="installed_app_message_feedback",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
MessageMoreLikeThisApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||||
|
endpoint="installed_app_more_like_this",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
MessageSuggestedQuestionApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||||
|
endpoint="installed_app_suggested_question",
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from flask_restful import fields, marshal_with
|
from flask_restful import fields, marshal_with
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
@ -11,33 +10,32 @@ from services.app_service import AppService
|
||||||
|
|
||||||
class AppParameterApi(InstalledAppResource):
|
class AppParameterApi(InstalledAppResource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
||||||
variable_fields = {
|
variable_fields = {
|
||||||
'key': fields.String,
|
"key": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'description': fields.String,
|
"description": fields.String,
|
||||||
'type': fields.String,
|
"type": fields.String,
|
||||||
'default': fields.String,
|
"default": fields.String,
|
||||||
'max_length': fields.Integer,
|
"max_length": fields.Integer,
|
||||||
'options': fields.List(fields.String)
|
"options": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
system_parameters_fields = {
|
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||||
'image_file_size_limit': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
parameters_fields = {
|
parameters_fields = {
|
||||||
'opening_statement': fields.String,
|
"opening_statement": fields.String,
|
||||||
'suggested_questions': fields.Raw,
|
"suggested_questions": fields.Raw,
|
||||||
'suggested_questions_after_answer': fields.Raw,
|
"suggested_questions_after_answer": fields.Raw,
|
||||||
'speech_to_text': fields.Raw,
|
"speech_to_text": fields.Raw,
|
||||||
'text_to_speech': fields.Raw,
|
"text_to_speech": fields.Raw,
|
||||||
'retriever_resource': fields.Raw,
|
"retriever_resource": fields.Raw,
|
||||||
'annotation_reply': fields.Raw,
|
"annotation_reply": fields.Raw,
|
||||||
'more_like_this': fields.Raw,
|
"more_like_this": fields.Raw,
|
||||||
'user_input_form': fields.Raw,
|
"user_input_form": fields.Raw,
|
||||||
'sensitive_word_avoidance': fields.Raw,
|
"sensitive_word_avoidance": fields.Raw,
|
||||||
'file_upload': fields.Raw,
|
"file_upload": fields.Raw,
|
||||||
'system_parameters': fields.Nested(system_parameters_fields)
|
"system_parameters": fields.Nested(system_parameters_fields),
|
||||||
}
|
}
|
||||||
|
|
||||||
@marshal_with(parameters_fields)
|
@marshal_with(parameters_fields)
|
||||||
|
@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource):
|
||||||
app_model_config = app_model.app_model_config
|
app_model_config = app_model.app_model_config
|
||||||
features_dict = app_model_config.to_dict()
|
features_dict = app_model_config.to_dict()
|
||||||
|
|
||||||
user_input_form = features_dict.get('user_input_form', [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'opening_statement': features_dict.get('opening_statement'),
|
"opening_statement": features_dict.get("opening_statement"),
|
||||||
'suggested_questions': features_dict.get('suggested_questions', []),
|
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||||
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
|
"suggested_questions_after_answer": features_dict.get(
|
||||||
{"enabled": False}),
|
"suggested_questions_after_answer", {"enabled": False}
|
||||||
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
|
),
|
||||||
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
|
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||||
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
|
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||||
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
|
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||||
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
|
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||||
'user_input_form': user_input_form,
|
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||||
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
|
"user_input_form": user_input_form,
|
||||||
{"enabled": False, "type": "", "configs": []}),
|
"sensitive_word_avoidance": features_dict.get(
|
||||||
'file_upload': features_dict.get('file_upload', {"image": {
|
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||||
|
),
|
||||||
|
"file_upload": features_dict.get(
|
||||||
|
"file_upload",
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
"enabled": False,
|
"enabled": False,
|
||||||
"number_limits": 3,
|
"number_limits": 3,
|
||||||
"detail": "high",
|
"detail": "high",
|
||||||
"transfer_methods": ["remote_url", "local_file"]
|
"transfer_methods": ["remote_url", "local_file"],
|
||||||
}}),
|
|
||||||
'system_parameters': {
|
|
||||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
|
api.add_resource(
|
||||||
endpoint='installed_app_parameters')
|
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
|
||||||
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
|
)
|
||||||
|
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||||
|
|
|
@ -8,28 +8,28 @@ from libs.login import login_required
|
||||||
from services.recommended_app_service import RecommendedAppService
|
from services.recommended_app_service import RecommendedAppService
|
||||||
|
|
||||||
app_fields = {
|
app_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'mode': fields.String,
|
"mode": fields.String,
|
||||||
'icon': fields.String,
|
"icon": fields.String,
|
||||||
'icon_background': fields.String
|
"icon_background": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
recommended_app_fields = {
|
recommended_app_fields = {
|
||||||
'app': fields.Nested(app_fields, attribute='app'),
|
"app": fields.Nested(app_fields, attribute="app"),
|
||||||
'app_id': fields.String,
|
"app_id": fields.String,
|
||||||
'description': fields.String(attribute='description'),
|
"description": fields.String(attribute="description"),
|
||||||
'copyright': fields.String,
|
"copyright": fields.String,
|
||||||
'privacy_policy': fields.String,
|
"privacy_policy": fields.String,
|
||||||
'custom_disclaimer': fields.String,
|
"custom_disclaimer": fields.String,
|
||||||
'category': fields.String,
|
"category": fields.String,
|
||||||
'position': fields.Integer,
|
"position": fields.Integer,
|
||||||
'is_listed': fields.Boolean
|
"is_listed": fields.Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
recommended_app_list_fields = {
|
recommended_app_list_fields = {
|
||||||
'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
|
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
|
||||||
'categories': fields.List(fields.String)
|
"categories": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
# language args
|
# language args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('language', type=str, location='args')
|
parser.add_argument("language", type=str, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.get('language') and args.get('language') in languages:
|
if args.get("language") and args.get("language") in languages:
|
||||||
language_prefix = args.get('language')
|
language_prefix = args.get("language")
|
||||||
elif current_user and current_user.interface_language:
|
elif current_user and current_user.interface_language:
|
||||||
language_prefix = current_user.interface_language
|
language_prefix = current_user.interface_language
|
||||||
else:
|
else:
|
||||||
|
@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
|
||||||
return RecommendedAppService.get_recommend_app_detail(app_id)
|
return RecommendedAppService.get_recommend_app_detail(app_id)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(RecommendedAppListApi, '/explore/apps')
|
api.add_resource(RecommendedAppListApi, "/explore/apps")
|
||||||
api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')
|
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
|
||||||
|
|
|
@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
feedback_fields = {
|
feedback_fields = {"rating": fields.String}
|
||||||
'rating': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
message_fields = {
|
message_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'inputs': fields.Raw,
|
"inputs": fields.Raw,
|
||||||
'query': fields.String,
|
"query": fields.String,
|
||||||
'answer': fields.String,
|
"answer": fields.String,
|
||||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
'created_at': TimestampField
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListApi(InstalledAppResource):
|
class SavedMessageListApi(InstalledAppResource):
|
||||||
saved_message_infinite_scroll_pagination_fields = {
|
saved_message_infinite_scroll_pagination_fields = {
|
||||||
'limit': fields.Integer,
|
"limit": fields.Integer,
|
||||||
'has_more': fields.Boolean,
|
"has_more": fields.Boolean,
|
||||||
'data': fields.List(fields.Nested(message_fields))
|
"data": fields.List(fields.Nested(message_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
|
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
|
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
SavedMessageService.save(app_model, current_user, args['message_id'])
|
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageApi(InstalledAppResource):
|
class SavedMessageApi(InstalledAppResource):
|
||||||
|
@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource):
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
SavedMessageService.delete(app_model, current_user, message_id)
|
SavedMessageService.delete(app_model, current_user, message_id)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
|
api.add_resource(
|
||||||
api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
|
SavedMessageListApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/saved-messages",
|
||||||
|
endpoint="installed_app_saved_messages",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
SavedMessageApi,
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
|
||||||
|
endpoint="installed_app_saved_message",
|
||||||
|
)
|
||||||
|
|
|
@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
user=current_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
|
||||||
streaming=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
||||||
return {
|
return {"result": "success"}
|
||||||
"result": "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
|
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||||
api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
|
api.add_resource(
|
||||||
|
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||||
|
)
|
||||||
|
|
|
@ -14,29 +14,33 @@ def installed_app_required(view=None):
|
||||||
def decorator(view):
|
def decorator(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if not kwargs.get('installed_app_id'):
|
if not kwargs.get("installed_app_id"):
|
||||||
raise ValueError('missing installed_app_id in path parameters')
|
raise ValueError("missing installed_app_id in path parameters")
|
||||||
|
|
||||||
installed_app_id = kwargs.get('installed_app_id')
|
installed_app_id = kwargs.get("installed_app_id")
|
||||||
installed_app_id = str(installed_app_id)
|
installed_app_id = str(installed_app_id)
|
||||||
|
|
||||||
del kwargs['installed_app_id']
|
del kwargs["installed_app_id"]
|
||||||
|
|
||||||
installed_app = db.session.query(InstalledApp).filter(
|
installed_app = (
|
||||||
InstalledApp.id == str(installed_app_id),
|
db.session.query(InstalledApp)
|
||||||
InstalledApp.tenant_id == current_user.current_tenant_id
|
.filter(
|
||||||
).first()
|
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if installed_app is None:
|
if installed_app is None:
|
||||||
raise NotFound('Installed app not found')
|
raise NotFound("Installed app not found")
|
||||||
|
|
||||||
if not installed_app.app:
|
if not installed_app.app:
|
||||||
db.session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
raise NotFound('Installed app not found')
|
raise NotFound("Installed app not found")
|
||||||
|
|
||||||
return view(installed_app, *args, **kwargs)
|
return view(installed_app, *args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
if view:
|
if view:
|
||||||
|
|
|
@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
|
|
||||||
|
|
||||||
class CodeBasedExtensionAPI(Resource):
|
class CodeBasedExtensionAPI(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('module', type=str, required=True, location='args')
|
parser.add_argument("module", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||||
'module': args['module'],
|
|
||||||
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class APIBasedExtensionAPI(Resource):
|
class APIBasedExtensionAPI(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource):
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=args['name'],
|
name=args["name"],
|
||||||
api_endpoint=args['api_endpoint'],
|
api_endpoint=args["api_endpoint"],
|
||||||
api_key=args['api_key']
|
api_key=args["api_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data)
|
return APIBasedExtensionService.save(extension_data)
|
||||||
|
|
||||||
|
|
||||||
class APIBasedExtensionDetailAPI(Resource):
|
class APIBasedExtensionDetailAPI(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
extension_data_from_db.name = args['name']
|
extension_data_from_db.name = args["name"]
|
||||||
extension_data_from_db.api_endpoint = args['api_endpoint']
|
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||||
|
|
||||||
if args['api_key'] != HIDDEN_VALUE:
|
if args["api_key"] != HIDDEN_VALUE:
|
||||||
extension_data_from_db.api_key = args['api_key']
|
extension_data_from_db.api_key = args["api_key"]
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data_from_db)
|
return APIBasedExtensionService.save(extension_data_from_db)
|
||||||
|
|
||||||
|
@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
|
|
||||||
APIBasedExtensionService.delete(extension_data_from_db)
|
APIBasedExtensionService.delete(extension_data_from_db)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
|
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
|
||||||
|
|
||||||
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
|
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
|
||||||
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
|
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")
|
||||||
|
|
|
@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record
|
||||||
|
|
||||||
|
|
||||||
class FeatureApi(Resource):
|
class FeatureApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -24,5 +23,5 @@ class SystemFeatureApi(Resource):
|
||||||
return FeatureService.get_system_features().model_dump()
|
return FeatureService.get_system_features().model_dump()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FeatureApi, '/features')
|
api.add_resource(FeatureApi, "/features")
|
||||||
api.add_resource(SystemFeatureApi, '/system-features')
|
api.add_resource(SystemFeatureApi, "/system-features")
|
||||||
|
|
|
@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
|
||||||
class InitValidateAPI(Resource):
|
class InitValidateAPI(Resource):
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
init_status = get_init_validate_status()
|
init_status = get_init_validate_status()
|
||||||
if init_status:
|
if init_status:
|
||||||
return { 'status': 'finished' }
|
return {"status": "finished"}
|
||||||
return {'status': 'not_started' }
|
return {"status": "not_started"}
|
||||||
|
|
||||||
@only_edition_self_hosted
|
@only_edition_self_hosted
|
||||||
def post(self):
|
def post(self):
|
||||||
|
@ -29,22 +28,23 @@ class InitValidateAPI(Resource):
|
||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('password', type=str_len(30),
|
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||||
required=True, location='json')
|
input_password = parser.parse_args()["password"]
|
||||||
input_password = parser.parse_args()['password']
|
|
||||||
|
|
||||||
if input_password != os.environ.get('INIT_PASSWORD'):
|
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||||
session['is_init_validated'] = False
|
session["is_init_validated"] = False
|
||||||
raise InitValidateFailedError()
|
raise InitValidateFailedError()
|
||||||
|
|
||||||
session['is_init_validated'] = True
|
session["is_init_validated"] = True
|
||||||
return {'result': 'success'}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
||||||
def get_init_validate_status():
|
def get_init_validate_status():
|
||||||
if dify_config.EDITION == 'SELF_HOSTED':
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
if os.environ.get('INIT_PASSWORD'):
|
if os.environ.get("INIT_PASSWORD"):
|
||||||
return session.get('is_init_validated') or DifySetup.query.first()
|
return session.get("is_init_validated") or DifySetup.query.first()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
api.add_resource(InitValidateAPI, '/init')
|
|
||||||
|
api.add_resource(InitValidateAPI, "/init")
|
||||||
|
|
|
@ -4,14 +4,11 @@ from controllers.console import api
|
||||||
|
|
||||||
|
|
||||||
class PingApi(Resource):
|
class PingApi(Resource):
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""
|
"""
|
||||||
For connection health check
|
For connection health check
|
||||||
"""
|
"""
|
||||||
return {
|
return {"result": "pong"}
|
||||||
"result": "pong"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(PingApi, '/ping')
|
api.add_resource(PingApi, "/ping")
|
||||||
|
|
|
@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
|
||||||
class SetupApi(Resource):
|
class SetupApi(Resource):
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
if dify_config.EDITION == 'SELF_HOSTED':
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
setup_status = get_setup_status()
|
setup_status = get_setup_status()
|
||||||
if setup_status:
|
if setup_status:
|
||||||
return {
|
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||||
'step': 'finished',
|
return {"step": "not_started"}
|
||||||
'setup_at': setup_status.setup_at.isoformat()
|
return {"step": "finished"}
|
||||||
}
|
|
||||||
return {'step': 'not_started'}
|
|
||||||
return {'step': 'finished'}
|
|
||||||
|
|
||||||
@only_edition_self_hosted
|
@only_edition_self_hosted
|
||||||
def post(self):
|
def post(self):
|
||||||
|
@ -43,23 +39,17 @@ class SetupApi(Resource):
|
||||||
raise NotInitValidateError()
|
raise NotInitValidateError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('email', type=email,
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
required=True, location='json')
|
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||||
parser.add_argument('name', type=str_len(
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
30), required=True, location='json')
|
|
||||||
parser.add_argument('password', type=valid_password,
|
|
||||||
required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
RegisterService.setup(
|
RegisterService.setup(
|
||||||
email=args['email'],
|
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
|
||||||
name=args['name'],
|
|
||||||
password=args['password'],
|
|
||||||
ip_address=get_remote_ip(request)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {'result': 'success'}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
||||||
def setup_required(view):
|
def setup_required(view):
|
||||||
|
@ -78,9 +68,10 @@ def setup_required(view):
|
||||||
|
|
||||||
|
|
||||||
def get_setup_status():
|
def get_setup_status():
|
||||||
if dify_config.EDITION == 'SELF_HOSTED':
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
return DifySetup.query.first()
|
return DifySetup.query.first()
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
api.add_resource(SetupApi, '/setup')
|
|
||||||
|
api.add_resource(SetupApi, "/setup")
|
||||||
|
|
|
@ -14,19 +14,18 @@ from services.tag_service import TagService
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
raise ValueError('Name must be between 1 to 50 characters.')
|
raise ValueError("Name must be between 1 to 50 characters.")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
class TagListApi(Resource):
|
class TagListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(tag_fields)
|
@marshal_with(tag_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
tag_type = request.args.get('type', type=str)
|
tag_type = request.args.get("type", type=str)
|
||||||
keyword = request.args.get('keyword', default=None, type=str)
|
keyword = request.args.get("keyword", default=None, type=str)
|
||||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||||
|
|
||||||
return tags, 200
|
return tags, 200
|
||||||
|
@ -40,28 +39,21 @@ class TagListApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', nullable=False, required=True,
|
parser.add_argument(
|
||||||
help='Name must be between 1 to 50 characters.',
|
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||||
type=_validate_name)
|
)
|
||||||
parser.add_argument('type', type=str, location='json',
|
parser.add_argument(
|
||||||
choices=Tag.TAG_TYPE_LIST,
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
nullable=True,
|
)
|
||||||
help='Invalid tag type.')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tag = TagService.save_tags(args)
|
tag = TagService.save_tags(args)
|
||||||
|
|
||||||
response = {
|
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||||
'id': tag.id,
|
|
||||||
'name': tag.name,
|
|
||||||
'type': tag.type,
|
|
||||||
'binding_count': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
class TagUpdateDeleteApi(Resource):
|
class TagUpdateDeleteApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', nullable=False, required=True,
|
parser.add_argument(
|
||||||
help='Name must be between 1 to 50 characters.',
|
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||||
type=_validate_name)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tag = TagService.update_tags(args, tag_id)
|
tag = TagService.update_tags(args, tag_id)
|
||||||
|
|
||||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||||
|
|
||||||
response = {
|
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||||
'id': tag.id,
|
|
||||||
'name': tag.name,
|
|
||||||
'type': tag.type,
|
|
||||||
'binding_count': binding_count
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class TagBindingCreateApi(Resource):
|
class TagBindingCreateApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
|
parser.add_argument(
|
||||||
help='Tag IDs is required.')
|
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||||
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
|
)
|
||||||
help='Target ID is required.')
|
parser.add_argument(
|
||||||
parser.add_argument('type', type=str, location='json',
|
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
||||||
choices=Tag.TAG_TYPE_LIST,
|
)
|
||||||
nullable=True,
|
parser.add_argument(
|
||||||
help='Invalid tag type.')
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.save_tag_binding(args)
|
TagService.save_tag_binding(args)
|
||||||
|
|
||||||
|
@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class TagBindingDeleteApi(Resource):
|
class TagBindingDeleteApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tag_id', type=str, nullable=False, required=True,
|
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||||
help='Tag ID is required.')
|
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||||
parser.add_argument('target_id', type=str, nullable=False, required=True,
|
parser.add_argument(
|
||||||
help='Target ID is required.')
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
parser.add_argument('type', type=str, location='json',
|
)
|
||||||
choices=Tag.TAG_TYPE_LIST,
|
|
||||||
nullable=True,
|
|
||||||
help='Invalid tag type.')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.delete_tag_binding(args)
|
TagService.delete_tag_binding(args)
|
||||||
|
|
||||||
return 200
|
return 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(TagListApi, '/tags')
|
api.add_resource(TagListApi, "/tags")
|
||||||
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
|
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
|
||||||
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
|
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
|
||||||
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
|
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -11,42 +10,39 @@ from . import api
|
||||||
|
|
||||||
|
|
||||||
class VersionApi(Resource):
|
class VersionApi(Resource):
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('current_version', type=str, required=True, location='args')
|
parser.add_argument("current_version", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'version': dify_config.CURRENT_VERSION,
|
"version": dify_config.CURRENT_VERSION,
|
||||||
'release_date': '',
|
"release_date": "",
|
||||||
'release_notes': '',
|
"release_notes": "",
|
||||||
'can_auto_update': False,
|
"can_auto_update": False,
|
||||||
'features': {
|
"features": {
|
||||||
'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
|
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
|
||||||
'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
|
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if not check_update_url:
|
if not check_update_url:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(check_update_url, {
|
response = requests.get(check_update_url, {"current_version": args.get("current_version")})
|
||||||
'current_version': args.get('current_version')
|
|
||||||
})
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logging.warning("Check update version error: {}.".format(str(error)))
|
logging.warning("Check update version error: {}.".format(str(error)))
|
||||||
result['version'] = args.get('current_version')
|
result["version"] = args.get("current_version")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
content = json.loads(response.content)
|
content = json.loads(response.content)
|
||||||
result['version'] = content['version']
|
result["version"] = content["version"]
|
||||||
result['release_date'] = content['releaseDate']
|
result["release_date"] = content["releaseDate"]
|
||||||
result['release_notes'] = content['releaseNotes']
|
result["release_notes"] = content["releaseNotes"]
|
||||||
result['can_auto_update'] = content['canAutoUpdate']
|
result["can_auto_update"] = content["canAutoUpdate"]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(VersionApi, '/version')
|
api.add_resource(VersionApi, "/version")
|
||||||
|
|
|
@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr
|
||||||
|
|
||||||
|
|
||||||
class AccountInitApi(Resource):
|
class AccountInitApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
def post(self):
|
def post(self):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
if account.status == 'active':
|
if account.status == "active":
|
||||||
raise AccountAlreadyInitedError()
|
raise AccountAlreadyInitedError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
if dify_config.EDITION == 'CLOUD':
|
if dify_config.EDITION == "CLOUD":
|
||||||
parser.add_argument('invitation_code', type=str, location='json')
|
parser.add_argument("invitation_code", type=str, location="json")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
||||||
'interface_language', type=supported_language, required=True, location='json')
|
parser.add_argument("timezone", type=timezone, required=True, location="json")
|
||||||
parser.add_argument('timezone', type=timezone,
|
|
||||||
required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if dify_config.EDITION == 'CLOUD':
|
if dify_config.EDITION == "CLOUD":
|
||||||
if not args['invitation_code']:
|
if not args["invitation_code"]:
|
||||||
raise ValueError('invitation_code is required')
|
raise ValueError("invitation_code is required")
|
||||||
|
|
||||||
# check invitation code
|
# check invitation code
|
||||||
invitation_code = db.session.query(InvitationCode).filter(
|
invitation_code = (
|
||||||
InvitationCode.code == args['invitation_code'],
|
db.session.query(InvitationCode)
|
||||||
InvitationCode.status == 'unused',
|
.filter(
|
||||||
).first()
|
InvitationCode.code == args["invitation_code"],
|
||||||
|
InvitationCode.status == "unused",
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not invitation_code:
|
if not invitation_code:
|
||||||
raise InvalidInvitationCodeError()
|
raise InvalidInvitationCodeError()
|
||||||
|
|
||||||
invitation_code.status = 'used'
|
invitation_code.status = "used"
|
||||||
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
invitation_code.used_by_tenant_id = account.current_tenant_id
|
invitation_code.used_by_tenant_id = account.current_tenant_id
|
||||||
invitation_code.used_by_account_id = account.id
|
invitation_code.used_by_account_id = account.id
|
||||||
|
|
||||||
account.interface_language = args['interface_language']
|
account.interface_language = args["interface_language"]
|
||||||
account.timezone = args['timezone']
|
account.timezone = args["timezone"]
|
||||||
account.interface_theme = 'light'
|
account.interface_theme = "light"
|
||||||
account.status = 'active'
|
account.status = "active"
|
||||||
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class AccountProfileApi(Resource):
|
class AccountProfileApi(Resource):
|
||||||
|
@ -90,15 +91,14 @@ class AccountNameApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate account name length
|
# Validate account name length
|
||||||
if len(args['name']) < 3 or len(args['name']) > 30:
|
if len(args["name"]) < 3 or len(args["name"]) > 30:
|
||||||
raise ValueError(
|
raise ValueError("Account name must be between 3 and 30 characters.")
|
||||||
"Account name must be between 3 and 30 characters.")
|
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, name=args['name'])
|
updated_account = AccountService.update_account(current_user, name=args["name"])
|
||||||
|
|
||||||
return updated_account
|
return updated_account
|
||||||
|
|
||||||
|
@ -110,10 +110,10 @@ class AccountAvatarApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('avatar', type=str, required=True, location='json')
|
parser.add_argument("avatar", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, avatar=args['avatar'])
|
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
||||||
|
|
||||||
return updated_account
|
return updated_account
|
||||||
|
|
||||||
|
@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument(
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
||||||
'interface_language', type=supported_language, required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])
|
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
||||||
|
|
||||||
return updated_account
|
return updated_account
|
||||||
|
|
||||||
|
@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('interface_theme', type=str, choices=[
|
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
||||||
'light', 'dark'], required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])
|
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
||||||
|
|
||||||
return updated_account
|
return updated_account
|
||||||
|
|
||||||
|
@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('timezone', type=str,
|
parser.add_argument("timezone", type=str, required=True, location="json")
|
||||||
required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
||||||
if args['timezone'] not in pytz.all_timezones:
|
if args["timezone"] not in pytz.all_timezones:
|
||||||
raise ValueError("Invalid timezone string.")
|
raise ValueError("Invalid timezone string.")
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
|
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
|
||||||
|
|
||||||
return updated_account
|
return updated_account
|
||||||
|
|
||||||
|
@ -177,20 +174,16 @@ class AccountPasswordApi(Resource):
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('password', type=str,
|
parser.add_argument("password", type=str, required=False, location="json")
|
||||||
required=False, location='json')
|
parser.add_argument("new_password", type=str, required=True, location="json")
|
||||||
parser.add_argument('new_password', type=str,
|
parser.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||||
required=True, location='json')
|
|
||||||
parser.add_argument('repeat_new_password', type=str,
|
|
||||||
required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args['new_password'] != args['repeat_new_password']:
|
if args["new_password"] != args["repeat_new_password"]:
|
||||||
raise RepeatPasswordNotMatchError()
|
raise RepeatPasswordNotMatchError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
AccountService.update_account_password(
|
AccountService.update_account_password(current_user, args["password"], args["new_password"])
|
||||||
current_user, args['password'], args['new_password'])
|
|
||||||
except ServiceCurrentPasswordIncorrectError:
|
except ServiceCurrentPasswordIncorrectError:
|
||||||
raise CurrentPasswordIncorrectError()
|
raise CurrentPasswordIncorrectError()
|
||||||
|
|
||||||
|
@ -199,14 +192,14 @@ class AccountPasswordApi(Resource):
|
||||||
|
|
||||||
class AccountIntegrateApi(Resource):
|
class AccountIntegrateApi(Resource):
|
||||||
integrate_fields = {
|
integrate_fields = {
|
||||||
'provider': fields.String,
|
"provider": fields.String,
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'is_bound': fields.Boolean,
|
"is_bound": fields.Boolean,
|
||||||
'link': fields.String
|
"link": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
integrate_list_fields = {
|
integrate_list_fields = {
|
||||||
'data': fields.List(fields.Nested(integrate_fields)),
|
"data": fields.List(fields.Nested(integrate_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
account_integrates = db.session.query(AccountIntegrate).filter(
|
account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
|
||||||
AccountIntegrate.account_id == account.id).all()
|
|
||||||
|
|
||||||
base_url = request.url_root.rstrip('/')
|
base_url = request.url_root.rstrip("/")
|
||||||
oauth_base_path = "/console/api/oauth/login"
|
oauth_base_path = "/console/api/oauth/login"
|
||||||
providers = ["github", "google"]
|
providers = ["github", "google"]
|
||||||
|
|
||||||
|
@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource):
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
|
existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
|
||||||
if existing_integrate:
|
if existing_integrate:
|
||||||
integrate_data.append({
|
integrate_data.append(
|
||||||
'id': existing_integrate.id,
|
{
|
||||||
'provider': provider,
|
"id": existing_integrate.id,
|
||||||
'created_at': existing_integrate.created_at,
|
"provider": provider,
|
||||||
'is_bound': True,
|
"created_at": existing_integrate.created_at,
|
||||||
'link': None
|
"is_bound": True,
|
||||||
})
|
"link": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
integrate_data.append({
|
integrate_data.append(
|
||||||
'id': None,
|
{
|
||||||
'provider': provider,
|
"id": None,
|
||||||
'created_at': None,
|
"provider": provider,
|
||||||
'is_bound': False,
|
"created_at": None,
|
||||||
'link': f'{base_url}{oauth_base_path}/{provider}'
|
"is_bound": False,
|
||||||
})
|
"link": f"{base_url}{oauth_base_path}/{provider}",
|
||||||
|
}
|
||||||
return {'data': integrate_data}
|
)
|
||||||
|
|
||||||
|
|
||||||
|
return {"data": integrate_data}
|
||||||
|
|
||||||
|
|
||||||
# Register API resources
|
# Register API resources
|
||||||
api.add_resource(AccountInitApi, '/account/init')
|
api.add_resource(AccountInitApi, "/account/init")
|
||||||
api.add_resource(AccountProfileApi, '/account/profile')
|
api.add_resource(AccountProfileApi, "/account/profile")
|
||||||
api.add_resource(AccountNameApi, '/account/name')
|
api.add_resource(AccountNameApi, "/account/name")
|
||||||
api.add_resource(AccountAvatarApi, '/account/avatar')
|
api.add_resource(AccountAvatarApi, "/account/avatar")
|
||||||
api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
|
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
|
||||||
api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
|
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
|
||||||
api.add_resource(AccountTimezoneApi, '/account/timezone')
|
api.add_resource(AccountTimezoneApi, "/account/timezone")
|
||||||
api.add_resource(AccountPasswordApi, '/account/password')
|
api.add_resource(AccountPasswordApi, "/account/password")
|
||||||
api.add_resource(AccountIntegrateApi, '/account/integrates')
|
api.add_resource(AccountIntegrateApi, "/account/integrates")
|
||||||
# api.add_resource(AccountEmailApi, '/account/email')
|
# api.add_resource(AccountEmailApi, '/account/email')
|
||||||
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|
||||||
|
|
|
@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class RepeatPasswordNotMatchError(BaseHTTPException):
|
class RepeatPasswordNotMatchError(BaseHTTPException):
|
||||||
error_code = 'repeat_password_not_match'
|
error_code = "repeat_password_not_match"
|
||||||
description = "New password and repeat password does not match."
|
description = "New password and repeat password does not match."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class CurrentPasswordIncorrectError(BaseHTTPException):
|
class CurrentPasswordIncorrectError(BaseHTTPException):
|
||||||
error_code = 'current_password_incorrect'
|
error_code = "current_password_incorrect"
|
||||||
description = "Current password is incorrect."
|
description = "Current password is incorrect."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderRequestFailedError(BaseHTTPException):
|
class ProviderRequestFailedError(BaseHTTPException):
|
||||||
error_code = 'provider_request_failed'
|
error_code = "provider_request_failed"
|
||||||
description = None
|
description = None
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class InvalidInvitationCodeError(BaseHTTPException):
|
class InvalidInvitationCodeError(BaseHTTPException):
|
||||||
error_code = 'invalid_invitation_code'
|
error_code = "invalid_invitation_code"
|
||||||
description = "Invalid invitation code."
|
description = "Invalid invitation code."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AccountAlreadyInitedError(BaseHTTPException):
|
class AccountAlreadyInitedError(BaseHTTPException):
|
||||||
error_code = 'account_already_inited'
|
error_code = "account_already_inited"
|
||||||
description = "The account has been initialized. Please refresh the page."
|
description = "The account has been initialized. Please refresh the page."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AccountNotInitializedError(BaseHTTPException):
|
class AccountNotInitializedError(BaseHTTPException):
|
||||||
error_code = 'account_not_initialized'
|
error_code = "account_not_initialized"
|
||||||
description = "The account has not been initialized yet. Please proceed with the initialization process first."
|
description = "The account has not been initialized yet. Please proceed with the initialization process first."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
|
@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# validate model load balancing credentials
|
# validate model load balancing credentials
|
||||||
|
@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
model_load_balancing_service.validate_load_balancing_credentials(
|
model_load_balancing_service.validate_load_balancing_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args['model'],
|
model=args["model"],
|
||||||
model_type=args['model_type'],
|
model_type=args["model_type"],
|
||||||
credentials=args['credentials']
|
credentials=args["credentials"],
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
result = False
|
result = False
|
||||||
error = str(ex)
|
error = str(ex)
|
||||||
|
|
||||||
response = {'result': 'success' if result else 'error'}
|
response = {"result": "success" if result else "error"}
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
response['error'] = error
|
response["error"] = error
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# validate model load balancing config credentials
|
# validate model load balancing config credentials
|
||||||
|
@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||||
model_load_balancing_service.validate_load_balancing_credentials(
|
model_load_balancing_service.validate_load_balancing_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args['model'],
|
model=args["model"],
|
||||||
model_type=args['model_type'],
|
model_type=args["model_type"],
|
||||||
credentials=args['credentials'],
|
credentials=args["credentials"],
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
result = False
|
result = False
|
||||||
error = str(ex)
|
error = str(ex)
|
||||||
|
|
||||||
response = {'result': 'success' if result else 'error'}
|
response = {"result": "success" if result else "error"}
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
response['error'] = error
|
response["error"] = error
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Load Balancing Config
|
# Load Balancing Config
|
||||||
api.add_resource(LoadBalancingCredentialsValidateApi,
|
api.add_resource(
|
||||||
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
|
LoadBalancingCredentialsValidateApi,
|
||||||
|
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
|
||||||
|
)
|
||||||
|
|
||||||
api.add_resource(LoadBalancingConfigCredentialsValidateApi,
|
api.add_resource(
|
||||||
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')
|
LoadBalancingConfigCredentialsValidateApi,
|
||||||
|
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||||
|
)
|
||||||
|
|
|
@ -23,7 +23,7 @@ class MemberListApi(Resource):
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||||
return {'result': 'success', 'accounts': members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
|
|
||||||
class MemberInviteEmailApi(Resource):
|
class MemberInviteEmailApi(Resource):
|
||||||
|
@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('members')
|
@cloud_edition_billing_resource_check("members")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('emails', type=str, required=True, location='json', action='append')
|
parser.add_argument("emails", type=str, required=True, location="json", action="append")
|
||||||
parser.add_argument('role', type=str, required=True, default='admin', location='json')
|
parser.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||||
parser.add_argument('language', type=str, required=False, location='json')
|
parser.add_argument("language", type=str, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
invitee_emails = args['emails']
|
invitee_emails = args["emails"]
|
||||||
invitee_role = args['role']
|
invitee_role = args["role"]
|
||||||
interface_language = args['language']
|
interface_language = args["language"]
|
||||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
invitation_results = []
|
invitation_results = []
|
||||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||||
for invitee_email in invitee_emails:
|
for invitee_email in invitee_emails:
|
||||||
try:
|
try:
|
||||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
|
token = RegisterService.invite_new_member(
|
||||||
invitation_results.append({
|
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||||
'status': 'success',
|
)
|
||||||
'email': invitee_email,
|
invitation_results.append(
|
||||||
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
|
{
|
||||||
})
|
"status": "success",
|
||||||
|
"email": invitee_email,
|
||||||
|
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
|
||||||
|
}
|
||||||
|
)
|
||||||
except AccountAlreadyInTenantError:
|
except AccountAlreadyInTenantError:
|
||||||
invitation_results.append({
|
invitation_results.append(
|
||||||
'status': 'success',
|
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||||
'email': invitee_email,
|
)
|
||||||
'url': f'{console_web_url}/signin'
|
|
||||||
})
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
invitation_results.append({
|
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||||
'status': 'failed',
|
|
||||||
'email': invitee_email,
|
|
||||||
'message': str(e)
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'result': 'success',
|
"result": "success",
|
||||||
'invitation_results': invitation_results,
|
"invitation_results": invitation_results,
|
||||||
}, 201
|
}, 201
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource):
|
||||||
try:
|
try:
|
||||||
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
|
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
|
||||||
except services.errors.account.CannotOperateSelfError as e:
|
except services.errors.account.CannotOperateSelfError as e:
|
||||||
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
|
return {"code": "cannot-operate-self", "message": str(e)}, 400
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
return {'code': 'forbidden', 'message': str(e)}, 403
|
return {"code": "forbidden", "message": str(e)}, 403
|
||||||
except services.errors.account.MemberNotInTenantError as e:
|
except services.errors.account.MemberNotInTenantError as e:
|
||||||
return {'code': 'member-not-found', 'message': str(e)}, 404
|
return {"code": "member-not-found", "message": str(e)}, 404
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class MemberUpdateRoleApi(Resource):
|
class MemberUpdateRoleApi(Resource):
|
||||||
|
@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, member_id):
|
def put(self, member_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('role', type=str, required=True, location='json')
|
parser.add_argument("role", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
new_role = args['role']
|
new_role = args["role"]
|
||||||
|
|
||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
|
@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource):
|
||||||
|
|
||||||
# todo: 403
|
# todo: 403
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class DatasetOperatorMemberListApi(Resource):
|
class DatasetOperatorMemberListApi(Resource):
|
||||||
|
@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource):
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||||
return {'result': 'success', 'accounts': members}, 200
|
return {"result": "success", "accounts": members}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(MemberListApi, '/workspaces/current/members')
|
api.add_resource(MemberListApi, "/workspaces/current/members")
|
||||||
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
|
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
|
||||||
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
|
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
|
||||||
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
|
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
|
||||||
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
|
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
|
||||||
|
|
|
@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderListApi(Resource):
|
class ModelProviderListApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -25,21 +24,23 @@ class ModelProviderListApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model_type', type=str, required=False, nullable=True,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='args')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
nullable=True,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
provider_list = model_provider_service.get_provider_list(
|
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_type=args.get('model_type')
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonable_encoder({"data": provider_list})
|
return jsonable_encoder({"data": provider_list})
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderCredentialApi(Resource):
|
class ModelProviderCredentialApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
credentials = model_provider_service.get_provider_credentials(
|
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider=provider
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {"credentials": credentials}
|
||||||
"credentials": credentials
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderValidateApi(Resource):
|
class ModelProviderValidateApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.provider_credentials_validate(
|
model_provider_service.provider_credentials_validate(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||||
provider=provider,
|
|
||||||
credentials=args['credentials']
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
result = False
|
result = False
|
||||||
error = str(ex)
|
error = str(ex)
|
||||||
|
|
||||||
response = {'result': 'success' if result else 'error'}
|
response = {"result": "success" if result else "error"}
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
response['error'] = error
|
response["error"] = error
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderApi(Resource):
|
class ModelProviderApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -103,21 +94,19 @@ class ModelProviderApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.save_provider_credentials(
|
model_provider_service.save_provider_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
|
||||||
provider=provider,
|
|
||||||
credentials=args['credentials']
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ValueError(str(ex))
|
raise ValueError(str(ex))
|
||||||
|
|
||||||
return {'result': 'success'}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -127,12 +116,9 @@ class ModelProviderApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_provider_credentials(
|
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
provider=provider
|
|
||||||
)
|
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderIconApi(Resource):
|
class ModelProviderIconApi(Resource):
|
||||||
|
@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource):
|
||||||
def get(self, provider: str, icon_type: str, lang: str):
|
def get(self, provider: str, icon_type: str, lang: str):
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||||
provider=provider,
|
provider=provider, icon_type=icon_type, lang=lang
|
||||||
icon_type=icon_type,
|
|
||||||
lang=lang
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||||
|
|
||||||
|
|
||||||
class PreferredProviderTypeUpdateApi(Resource):
|
class PreferredProviderTypeUpdateApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=['system', 'custom'], location='json')
|
"preferred_provider_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=["system", "custom"],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.switch_preferred_provider(
|
model_provider_service.switch_preferred_provider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
|
||||||
provider=provider,
|
|
||||||
preferred_provider_type=args['preferred_provider_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||||
|
@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if provider != 'anthropic':
|
if provider != "anthropic":
|
||||||
raise ValueError(f'provider name {provider} is invalid')
|
raise ValueError(f"provider name {provider} is invalid")
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
data = BillingService.get_model_provider_payment_link(provider_name=provider,
|
data = BillingService.get_model_provider_payment_link(
|
||||||
|
provider_name=provider,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
account_id=current_user.id,
|
account_id=current_user.id,
|
||||||
prefilled_email=current_user.email)
|
prefilled_email=current_user.email,
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
result = model_provider_service.free_quota_submit(
|
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
provider=provider
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
result = model_provider_service.free_quota_qualification_verify(
|
result = model_provider_service.free_quota_qualification_verify(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
|
||||||
provider=provider,
|
|
||||||
token=args['token']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||||
|
|
||||||
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
|
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
|
||||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
|
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
|
||||||
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
|
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
|
||||||
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
|
api.add_resource(
|
||||||
'<string:icon_type>/<string:lang>')
|
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
|
||||||
|
)
|
||||||
|
|
||||||
api.add_resource(PreferredProviderTypeUpdateApi,
|
api.add_resource(
|
||||||
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
|
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
|
||||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
)
|
||||||
'/workspaces/current/model-providers/<string:provider>/checkout-url')
|
api.add_resource(
|
||||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
|
||||||
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
|
)
|
||||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
api.add_resource(
|
||||||
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
|
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ModelProviderFreeQuotaQualificationVerifyApi,
|
||||||
|
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
|
||||||
|
)
|
||||||
|
|
|
@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
|
|
||||||
class DefaultModelApi(Resource):
|
class DefaultModelApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='args')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, model_type=args["model_type"]
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonable_encoder({
|
return jsonable_encoder({"data": default_model_entity})
|
||||||
"data": default_model_entity
|
|
||||||
})
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -46,38 +48,37 @@ class DefaultModelApi(Resource):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
|
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_settings = args['model_settings']
|
model_settings = args["model_settings"]
|
||||||
for model_setting in model_settings:
|
for model_setting in model_settings:
|
||||||
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
|
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
|
||||||
raise ValueError('invalid model type')
|
raise ValueError("invalid model type")
|
||||||
|
|
||||||
if 'provider' not in model_setting:
|
if "provider" not in model_setting:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 'model' not in model_setting:
|
if "model" not in model_setting:
|
||||||
raise ValueError('invalid model')
|
raise ValueError("invalid model")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_default_model_of_model_type(
|
model_provider_service.update_default_model_of_model_type(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_type=model_setting['model_type'],
|
model_type=model_setting["model_type"],
|
||||||
provider=model_setting['provider'],
|
provider=model_setting["provider"],
|
||||||
model=model_setting['model']
|
model=model_setting["model"],
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning(f"{model_setting['model_type']} save error")
|
logging.warning(f"{model_setting['model_type']} save error")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelApi(Resource):
|
class ModelProviderModelApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_provider(
|
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider=provider
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonable_encoder({
|
return jsonable_encoder({"data": models})
|
||||||
"data": models
|
|
||||||
})
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -104,62 +100,66 @@ class ModelProviderModelApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
|
type=str,
|
||||||
parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
|
required=True,
|
||||||
parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
|
|
||||||
if ('load_balancing' in args and args['load_balancing'] and
|
if (
|
||||||
'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
|
"load_balancing" in args
|
||||||
if 'configs' not in args['load_balancing']:
|
and args["load_balancing"]
|
||||||
raise ValueError('invalid load balancing configs')
|
and "enabled" in args["load_balancing"]
|
||||||
|
and args["load_balancing"]["enabled"]
|
||||||
|
):
|
||||||
|
if "configs" not in args["load_balancing"]:
|
||||||
|
raise ValueError("invalid load balancing configs")
|
||||||
|
|
||||||
# save load balancing configs
|
# save load balancing configs
|
||||||
model_load_balancing_service.update_load_balancing_configs(
|
model_load_balancing_service.update_load_balancing_configs(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args['model'],
|
model=args["model"],
|
||||||
model_type=args['model_type'],
|
model_type=args["model_type"],
|
||||||
configs=args['load_balancing']['configs']
|
configs=args["load_balancing"]["configs"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# enable load balancing
|
# enable load balancing
|
||||||
model_load_balancing_service.enable_model_load_balancing(
|
model_load_balancing_service.enable_model_load_balancing(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# disable load balancing
|
# disable load balancing
|
||||||
model_load_balancing_service.disable_model_load_balancing(
|
model_load_balancing_service.disable_model_load_balancing(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.get('config_from', '') != 'predefined-model':
|
if args.get("config_from", "") != "predefined-model":
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.save_model_credentials(
|
model_provider_service.save_model_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args['model'],
|
model=args["model"],
|
||||||
model_type=args['model_type'],
|
model_type=args["model_type"],
|
||||||
credentials=args['credentials']
|
credentials=args["credentials"],
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
logging.exception(f"save model credentials error: {ex}")
|
logging.exception(f"save model credentials error: {ex}")
|
||||||
raise ValueError(str(ex))
|
raise ValueError(str(ex))
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -171,24 +171,26 @@ class ModelProviderModelApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_model_credentials(
|
model_provider_service.remove_model_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelCredentialApi(Resource):
|
class ModelProviderModelCredentialApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -196,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='args')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
credentials = model_provider_service.get_model_credentials(
|
credentials = model_provider_service.get_model_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
|
||||||
provider=provider,
|
|
||||||
model_type=args['model_type'],
|
|
||||||
model=args['model']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
"load_balancing": {
|
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
|
||||||
"enabled": is_load_balancing_enabled,
|
|
||||||
"configs": load_balancing_configs
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelEnableApi(Resource):
|
class ModelProviderModelEnableApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -235,24 +233,26 @@ class ModelProviderModelEnableApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.enable_model(
|
model_provider_service.enable_model(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelDisableApi(Resource):
|
class ModelProviderModelDisableApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -260,24 +260,26 @@ class ModelProviderModelDisableApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.disable_model(
|
model_provider_service.disable_model(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||||
provider=provider,
|
|
||||||
model=args['model'],
|
|
||||||
model_type=args['model_type']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelValidateApi(Resource):
|
class ModelProviderModelValidateApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -285,10 +287,16 @@ class ModelProviderModelValidateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
parser.add_argument(
|
||||||
choices=[mt.value for mt in ModelType], location='json')
|
"model_type",
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
type=str,
|
||||||
|
required=True,
|
||||||
|
nullable=False,
|
||||||
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
@ -300,48 +308,42 @@ class ModelProviderModelValidateApi(Resource):
|
||||||
model_provider_service.model_credentials_validate(
|
model_provider_service.model_credentials_validate(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args['model'],
|
model=args["model"],
|
||||||
model_type=args['model_type'],
|
model_type=args["model_type"],
|
||||||
credentials=args['credentials']
|
credentials=args["credentials"],
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
result = False
|
result = False
|
||||||
error = str(ex)
|
error = str(ex)
|
||||||
|
|
||||||
response = {'result': 'success' if result else 'error'}
|
response = {"result": "success" if result else "error"}
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
response['error'] = error
|
response["error"] = error
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderModelParameterRuleApi(Resource):
|
class ModelProviderModelParameterRuleApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id, provider=provider, model=args["model"]
|
||||||
provider=provider,
|
|
||||||
model=args['model']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonable_encoder({
|
return jsonable_encoder({"data": parameter_rules})
|
||||||
"data": parameter_rules
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderAvailableModelApi(Resource):
|
class ModelProviderAvailableModelApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -349,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_model_type(
|
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_type=model_type
|
return jsonable_encoder({"data": models})
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
|
||||||
|
api.add_resource(
|
||||||
|
ModelProviderModelEnableApi,
|
||||||
|
"/workspaces/current/model-providers/<string:provider>/models/enable",
|
||||||
|
endpoint="model-provider-model-enable",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ModelProviderModelDisableApi,
|
||||||
|
"/workspaces/current/model-providers/<string:provider>/models/disable",
|
||||||
|
endpoint="model-provider-model-disable",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonable_encoder({
|
api.add_resource(
|
||||||
"data": models
|
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
|
||||||
})
|
)
|
||||||
|
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
|
||||||
|
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
|
||||||
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
|
|
||||||
api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
|
|
||||||
endpoint='model-provider-model-enable')
|
|
||||||
api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
|
|
||||||
endpoint='model-provider-model-disable')
|
|
||||||
api.add_resource(ModelProviderModelCredentialApi,
|
|
||||||
'/workspaces/current/model-providers/<string:provider>/models/credentials')
|
|
||||||
api.add_resource(ModelProviderModelValidateApi,
|
|
||||||
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
|
|
||||||
|
|
||||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
|
||||||
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
|
|
||||||
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
|
|
||||||
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
|
|
||||||
|
|
|
@ -28,10 +28,18 @@ class ToolProviderListApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser()
|
||||||
req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
|
req.add_argument(
|
||||||
|
"type",
|
||||||
|
type=str,
|
||||||
|
choices=["builtin", "model", "api", "workflow"],
|
||||||
|
required=False,
|
||||||
|
nullable=True,
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
|
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderListToolsApi(Resource):
|
class ToolBuiltinProviderListToolsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
provider,
|
provider,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderDeleteApi(Resource):
|
class ToolBuiltinProviderDeleteApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -64,6 +75,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||||
provider,
|
provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderUpdateApi(Resource):
|
class ToolBuiltinProviderUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -76,7 +88,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
provider,
|
provider,
|
||||||
args['credentials'],
|
args["credentials"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderGetCredentialsApi(Resource):
|
class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||||
provider,
|
provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderIconApi(Resource):
|
class ToolBuiltinProviderIconApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
|
@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource):
|
||||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderAddApi(Resource):
|
class ToolApiProviderAddApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -120,30 +135,31 @@ class ToolApiProviderAddApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
|
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||||
parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.create_api_tool_provider(
|
return ApiToolManageService.create_api_tool_provider(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['provider'],
|
args["provider"],
|
||||||
args['icon'],
|
args["icon"],
|
||||||
args['credentials'],
|
args["credentials"],
|
||||||
args['schema_type'],
|
args["schema_type"],
|
||||||
args['schema'],
|
args["schema"],
|
||||||
args.get('privacy_policy', ''),
|
args.get("privacy_policy", ""),
|
||||||
args.get('custom_disclaimer', ''),
|
args.get("custom_disclaimer", ""),
|
||||||
args.get('labels', []),
|
args.get("labels", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('url', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||||
current_user.id,
|
current_user.id,
|
||||||
current_user.current_tenant_id,
|
current_user.current_tenant_id,
|
||||||
args['url'],
|
args["url"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderListToolsApi(Resource):
|
class ToolApiProviderListToolsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
|
return jsonable_encoder(
|
||||||
|
ApiToolManageService.list_api_tool_provider_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['provider'],
|
args["provider"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderUpdateApi(Resource):
|
class ToolApiProviderUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -193,32 +213,33 @@ class ToolApiProviderUpdateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
|
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||||
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
|
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.update_api_tool_provider(
|
return ApiToolManageService.update_api_tool_provider(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['provider'],
|
args["provider"],
|
||||||
args['original_provider'],
|
args["original_provider"],
|
||||||
args['icon'],
|
args["icon"],
|
||||||
args['credentials'],
|
args["credentials"],
|
||||||
args['schema_type'],
|
args["schema_type"],
|
||||||
args['schema'],
|
args["schema"],
|
||||||
args['privacy_policy'],
|
args["privacy_policy"],
|
||||||
args['custom_disclaimer'],
|
args["custom_disclaimer"],
|
||||||
args.get('labels', []),
|
args.get("labels", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderDeleteApi(Resource):
|
class ToolApiProviderDeleteApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -232,16 +253,17 @@ class ToolApiProviderDeleteApi(Resource):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.delete_api_tool_provider(
|
return ApiToolManageService.delete_api_tool_provider(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['provider'],
|
args["provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderGetApi(Resource):
|
class ToolApiProviderGetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.get_api_tool_provider(
|
return ApiToolManageService.get_api_tool_provider(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['provider'],
|
args["provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderSchemaApi(Resource):
|
class ToolApiProviderSchemaApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.parser_api_schema(
|
return ApiToolManageService.parser_api_schema(
|
||||||
schema=args['schema'],
|
schema=args["schema"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderPreviousTestApi(Resource):
|
class ToolApiProviderPreviousTestApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
|
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return ApiToolManageService.test_api_tool_preview(
|
return ApiToolManageService.test_api_tool_preview(
|
||||||
current_user.current_tenant_id,
|
current_user.current_tenant_id,
|
||||||
args['provider_name'] if args['provider_name'] else '',
|
args["provider_name"] if args["provider_name"] else "",
|
||||||
args['tool_name'],
|
args["tool_name"],
|
||||||
args['credentials'],
|
args["credentials"],
|
||||||
args['parameters'],
|
args["parameters"],
|
||||||
args['schema_type'],
|
args["schema_type"],
|
||||||
args['schema'],
|
args["schema"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowProviderCreateApi(Resource):
|
class ToolWorkflowProviderCreateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -322,30 +348,31 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
|
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
|
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
|
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
|
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
|
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
|
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||||
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
return WorkflowToolManageService.create_workflow_tool(
|
return WorkflowToolManageService.create_workflow_tool(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_app_id'],
|
args["workflow_app_id"],
|
||||||
args['name'],
|
args["name"],
|
||||||
args['label'],
|
args["label"],
|
||||||
args['icon'],
|
args["icon"],
|
||||||
args['description'],
|
args["description"],
|
||||||
args['parameters'],
|
args["parameters"],
|
||||||
args['privacy_policy'],
|
args["privacy_policy"],
|
||||||
args.get('labels', []),
|
args.get("labels", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowProviderUpdateApi(Resource):
|
class ToolWorkflowProviderUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -358,33 +385,34 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
|
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
|
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
|
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
|
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
|
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
|
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||||
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
if not args['workflow_tool_id']:
|
if not args["workflow_tool_id"]:
|
||||||
raise ValueError('incorrect workflow_tool_id')
|
raise ValueError("incorrect workflow_tool_id")
|
||||||
|
|
||||||
return WorkflowToolManageService.update_workflow_tool(
|
return WorkflowToolManageService.update_workflow_tool(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_tool_id'],
|
args["workflow_tool_id"],
|
||||||
args['name'],
|
args["name"],
|
||||||
args['label'],
|
args["label"],
|
||||||
args['icon'],
|
args["icon"],
|
||||||
args['description'],
|
args["description"],
|
||||||
args['parameters'],
|
args["parameters"],
|
||||||
args['privacy_policy'],
|
args["privacy_policy"],
|
||||||
args.get('labels', []),
|
args.get("labels", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowProviderDeleteApi(Resource):
|
class ToolWorkflowProviderDeleteApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -397,16 +425,17 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser()
|
||||||
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
|
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
return WorkflowToolManageService.delete_workflow_tool(
|
return WorkflowToolManageService.delete_workflow_tool(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_tool_id'],
|
args["workflow_tool_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowProviderGetApi(Resource):
|
class ToolWorkflowProviderGetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
|
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
|
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.get('workflow_tool_id'):
|
if args.get("workflow_tool_id"):
|
||||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_tool_id'],
|
args["workflow_tool_id"],
|
||||||
)
|
)
|
||||||
elif args.get('workflow_app_id'):
|
elif args.get("workflow_app_id"):
|
||||||
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
|
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_app_id'],
|
args["workflow_app_id"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError('incorrect workflow_tool_id or workflow_app_id')
|
raise ValueError("incorrect workflow_tool_id or workflow_app_id")
|
||||||
|
|
||||||
return jsonable_encoder(tool)
|
return jsonable_encoder(tool)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowProviderListToolApi(Resource):
|
class ToolWorkflowProviderListToolApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
|
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
|
return jsonable_encoder(
|
||||||
|
WorkflowToolManageService.list_single_workflow_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args['workflow_tool_id'],
|
args["workflow_tool_id"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinListApi(Resource):
|
class ToolBuiltinListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -465,10 +498,16 @@ class ToolBuiltinListApi(Resource):
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
|
return jsonable_encoder(
|
||||||
|
[
|
||||||
|
provider.to_dict()
|
||||||
|
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
)])
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiListApi(Resource):
|
class ToolApiListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -478,10 +517,16 @@ class ToolApiListApi(Resource):
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
|
return jsonable_encoder(
|
||||||
|
[
|
||||||
|
provider.to_dict()
|
||||||
|
for provider in ApiToolManageService.list_api_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
)])
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolWorkflowListApi(Resource):
|
class ToolWorkflowListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -491,10 +536,16 @@ class ToolWorkflowListApi(Resource):
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
return jsonable_encoder(
|
||||||
|
[
|
||||||
|
provider.to_dict()
|
||||||
|
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
||||||
user_id,
|
user_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
)])
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolLabelsApi(Resource):
|
class ToolLabelsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -503,36 +554,41 @@ class ToolLabelsApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||||
|
|
||||||
|
|
||||||
# tool provider
|
# tool provider
|
||||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||||
|
|
||||||
# builtin tool provider
|
# builtin tool provider
|
||||||
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
|
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
|
||||||
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
|
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
|
||||||
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
|
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
|
||||||
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
|
api.add_resource(
|
||||||
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
|
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
|
||||||
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
|
||||||
|
)
|
||||||
|
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
|
||||||
|
|
||||||
# api tool provider
|
# api tool provider
|
||||||
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
|
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
|
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
|
||||||
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
|
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
|
||||||
api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
|
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
|
||||||
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
|
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
|
||||||
api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
|
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
|
||||||
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
|
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
|
||||||
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
|
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
|
||||||
|
|
||||||
# workflow tool provider
|
# workflow tool provider
|
||||||
api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
|
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
|
||||||
api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
|
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
|
||||||
api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
|
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
|
||||||
api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
|
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
|
||||||
api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
|
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
|
||||||
|
|
||||||
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
|
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
|
||||||
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
|
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
|
||||||
api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
|
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
|
||||||
|
|
||||||
api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')
|
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
|
||||||
|
|
|
@ -26,39 +26,34 @@ from services.file_service import FileService
|
||||||
from services.workspace_service import WorkspaceService
|
from services.workspace_service import WorkspaceService
|
||||||
|
|
||||||
provider_fields = {
|
provider_fields = {
|
||||||
'provider_name': fields.String,
|
"provider_name": fields.String,
|
||||||
'provider_type': fields.String,
|
"provider_type": fields.String,
|
||||||
'is_valid': fields.Boolean,
|
"is_valid": fields.Boolean,
|
||||||
'token_is_set': fields.Boolean,
|
"token_is_set": fields.Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
tenant_fields = {
|
tenant_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'plan': fields.String,
|
"plan": fields.String,
|
||||||
'status': fields.String,
|
"status": fields.String,
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'role': fields.String,
|
"role": fields.String,
|
||||||
'in_trial': fields.Boolean,
|
"in_trial": fields.Boolean,
|
||||||
'trial_end_reason': fields.String,
|
"trial_end_reason": fields.String,
|
||||||
'custom_config': fields.Raw(attribute='custom_config'),
|
"custom_config": fields.Raw(attribute="custom_config"),
|
||||||
}
|
}
|
||||||
|
|
||||||
tenants_fields = {
|
tenants_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'plan': fields.String,
|
"plan": fields.String,
|
||||||
'status': fields.String,
|
"status": fields.String,
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'current': fields.Boolean
|
"current": fields.Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
workspace_fields = {
|
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TenantListApi(Resource):
|
class TenantListApi(Resource):
|
||||||
|
@ -71,7 +66,7 @@ class TenantListApi(Resource):
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if tenant.id == current_user.current_tenant_id:
|
if tenant.id == current_user.current_tenant_id:
|
||||||
tenant.current = True # Set current=True for current tenant
|
tenant.current = True # Set current=True for current tenant
|
||||||
return {'workspaces': marshal(tenants, tenants_fields)}, 200
|
return {"workspaces": marshal(tenants, tenants_fields)}, 200
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceListApi(Resource):
|
class WorkspaceListApi(Resource):
|
||||||
|
@ -79,30 +74,36 @@ class WorkspaceListApi(Resource):
|
||||||
@admin_required
|
@admin_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
|
tenants = (
|
||||||
.paginate(page=args['page'], per_page=args['limit'])
|
db.session.query(Tenant)
|
||||||
|
.order_by(Tenant.created_at.desc())
|
||||||
|
.paginate(page=args["page"], per_page=args["limit"])
|
||||||
|
)
|
||||||
|
|
||||||
has_more = False
|
has_more = False
|
||||||
if len(tenants.items) == args['limit']:
|
if len(tenants.items) == args["limit"]:
|
||||||
current_page_first_tenant = tenants[-1]
|
current_page_first_tenant = tenants[-1]
|
||||||
rest_count = db.session.query(Tenant).filter(
|
rest_count = (
|
||||||
Tenant.created_at < current_page_first_tenant.created_at,
|
db.session.query(Tenant)
|
||||||
Tenant.id != current_page_first_tenant.id
|
.filter(
|
||||||
).count()
|
Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
if rest_count > 0:
|
if rest_count > 0:
|
||||||
has_more = True
|
has_more = True
|
||||||
total = db.session.query(Tenant).count()
|
total = db.session.query(Tenant).count()
|
||||||
return {
|
return {
|
||||||
'data': marshal(tenants.items, workspace_fields),
|
"data": marshal(tenants.items, workspace_fields),
|
||||||
'has_more': has_more,
|
"has_more": has_more,
|
||||||
'limit': args['limit'],
|
"limit": args["limit"],
|
||||||
'page': args['page'],
|
"page": args["page"],
|
||||||
'total': total
|
"total": total,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,8 +113,8 @@ class TenantApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(tenant_fields)
|
@marshal_with(tenant_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if request.path == '/info':
|
if request.path == "/info":
|
||||||
logging.warning('Deprecated URL /info was used.')
|
logging.warning("Deprecated URL /info was used.")
|
||||||
|
|
||||||
tenant = current_user.current_tenant
|
tenant = current_user.current_tenant
|
||||||
|
|
||||||
|
@ -125,7 +126,7 @@ class TenantApi(Resource):
|
||||||
tenant = tenants[0]
|
tenant = tenants[0]
|
||||||
# else, raise Unauthorized
|
# else, raise Unauthorized
|
||||||
else:
|
else:
|
||||||
raise Unauthorized('workspace is archived')
|
raise Unauthorized("workspace is archived")
|
||||||
|
|
||||||
return WorkspaceService.get_tenant_info(tenant), 200
|
return WorkspaceService.get_tenant_info(tenant), 200
|
||||||
|
|
||||||
|
@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('tenant_id', type=str, required=True, location='json')
|
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if tenant_id is valid, 403 if not
|
# check if tenant_id is valid, 403 if not
|
||||||
try:
|
try:
|
||||||
TenantService.switch_tenant(current_user, args['tenant_id'])
|
TenantService.switch_tenant(current_user, args["tenant_id"])
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AccountNotLinkTenantError("Account not link tenant")
|
raise AccountNotLinkTenantError("Account not link tenant")
|
||||||
|
|
||||||
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
|
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
|
||||||
|
|
||||||
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
|
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
|
||||||
|
|
||||||
|
|
||||||
class CustomConfigWorkspaceApi(Resource):
|
class CustomConfigWorkspaceApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('workspace_custom')
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('remove_webapp_brand', type=bool, location='json')
|
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||||
parser.add_argument('replace_webapp_logo', type=str, location='json')
|
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
|
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict = {
|
||||||
'remove_webapp_brand': args['remove_webapp_brand'],
|
"remove_webapp_brand": args["remove_webapp_brand"],
|
||||||
'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') ,
|
"replace_webapp_logo": args["replace_webapp_logo"]
|
||||||
|
if args["replace_webapp_logo"] is not None
|
||||||
|
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||||
}
|
}
|
||||||
|
|
||||||
tenant.custom_config_dict = custom_config_dict
|
tenant.custom_config_dict = custom_config_dict
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||||
|
|
||||||
|
|
||||||
class WebappLogoWorkspaceApi(Resource):
|
class WebappLogoWorkspaceApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check('workspace_custom')
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
|
|
||||||
extension = file.filename.split('.')[-1]
|
extension = file.filename.split(".")[-1]
|
||||||
if extension.lower() not in ['svg', 'png']:
|
if extension.lower() not in ["svg", "png"]:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -202,13 +205,13 @@ class WebappLogoWorkspaceApi(Resource):
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
return { 'id': upload_file.id }, 201
|
return {"id": upload_file.id}, 201
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
|
||||||
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
|
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
|
||||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
|
||||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
|
||||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
|
||||||
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
|
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
|
||||||
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
|
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
|
||||||
|
|
|
@ -16,7 +16,7 @@ def account_initialization_required(view):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
if account.status == 'uninitialized':
|
if account.status == "uninitialized":
|
||||||
raise AccountNotInitializedError()
|
raise AccountNotInitializedError()
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -27,7 +27,7 @@ def account_initialization_required(view):
|
||||||
def only_edition_cloud(view):
|
def only_edition_cloud(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if dify_config.EDITION != 'CLOUD':
|
if dify_config.EDITION != "CLOUD":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -38,7 +38,7 @@ def only_edition_cloud(view):
|
||||||
def only_edition_self_hosted(view):
|
def only_edition_self_hosted(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if dify_config.EDITION != 'SELF_HOSTED':
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -46,8 +46,9 @@ def only_edition_self_hosted(view):
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str,
|
def cloud_edition_billing_resource_check(
|
||||||
error_msg: str = "You have reached the limit of your subscription."):
|
resource: str, error_msg: str = "You have reached the limit of your subscription."
|
||||||
|
):
|
||||||
def interceptor(view):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
|
@ -58,22 +59,22 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||||
vector_space = features.vector_space
|
vector_space = features.vector_space
|
||||||
documents_upload_quota = features.documents_upload_quota
|
documents_upload_quota = features.documents_upload_quota
|
||||||
annotation_quota_limit = features.annotation_quota_limit
|
annotation_quota_limit = features.annotation_quota_limit
|
||||||
if resource == 'members' and 0 < members.limit <= members.size:
|
if resource == "members" and 0 < members.limit <= members.size:
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
elif resource == 'apps' and 0 < apps.limit <= apps.size:
|
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
|
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||||
source = request.args.get('source')
|
source = request.args.get("source")
|
||||||
if source == 'datasets':
|
if source == "datasets":
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
else:
|
else:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
elif resource == 'workspace_custom' and not features.can_replace_logo:
|
elif resource == "workspace_custom" and not features.can_replace_logo:
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
else:
|
else:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -85,15 +86,17 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str,
|
def cloud_edition_billing_knowledge_limit_check(
|
||||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
|
resource: str,
|
||||||
|
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||||
|
):
|
||||||
def interceptor(view):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == 'add_segment':
|
if resource == "add_segment":
|
||||||
if features.billing.subscription.plan == 'sandbox':
|
if features.billing.subscription.plan == "sandbox":
|
||||||
abort(403, error_msg)
|
abort(403, error_msg)
|
||||||
else:
|
else:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -112,7 +115,7 @@ def cloud_utm_record(view):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
utm_info = request.cookies.get('utm_info')
|
utm_info = request.cookies.get("utm_info")
|
||||||
|
|
||||||
if utm_info:
|
if utm_info:
|
||||||
utm_info = json.loads(utm_info)
|
utm_info = json.loads(utm_info)
|
||||||
|
|
|
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
bp = Blueprint('files', __name__)
|
bp = Blueprint("files", __name__)
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,20 +13,15 @@ class ImagePreviewApi(Resource):
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
timestamp = request.args.get('timestamp')
|
timestamp = request.args.get("timestamp")
|
||||||
nonce = request.args.get('nonce')
|
nonce = request.args.get("nonce")
|
||||||
sign = request.args.get('sign')
|
sign = request.args.get("sign")
|
||||||
|
|
||||||
if not timestamp or not nonce or not sign:
|
if not timestamp or not nonce or not sign:
|
||||||
return {'content': 'Invalid request.'}, 400
|
return {"content": "Invalid request."}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, mimetype = FileService.get_image_preview(
|
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
|
||||||
file_id,
|
|
||||||
timestamp,
|
|
||||||
nonce,
|
|
||||||
sign
|
|
||||||
)
|
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
@ -38,10 +33,10 @@ class WorkspaceWebappLogoApi(Resource):
|
||||||
workspace_id = str(workspace_id)
|
workspace_id = str(workspace_id)
|
||||||
|
|
||||||
custom_config = TenantService.get_custom_config(workspace_id)
|
custom_config = TenantService.get_custom_config(workspace_id)
|
||||||
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
|
webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
|
||||||
|
|
||||||
if not webapp_logo_file_id:
|
if not webapp_logo_file_id:
|
||||||
raise NotFound('webapp logo is not found')
|
raise NotFound("webapp logo is not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, mimetype = FileService.get_public_image_preview(
|
generator, mimetype = FileService.get_public_image_preview(
|
||||||
|
@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource):
|
||||||
return Response(generator, mimetype=mimetype)
|
return Response(generator, mimetype=mimetype)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
|
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
||||||
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
|
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
|
@ -13,18 +13,19 @@ class ToolFilePreviewApi(Resource):
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
parser.add_argument('timestamp', type=str, required=True, location='args')
|
parser.add_argument("timestamp", type=str, required=True, location="args")
|
||||||
parser.add_argument('nonce', type=str, required=True, location='args')
|
parser.add_argument("nonce", type=str, required=True, location="args")
|
||||||
parser.add_argument('sign', type=str, required=True, location='args')
|
parser.add_argument("sign", type=str, required=True, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not ToolFileManager.verify_file(file_id=file_id,
|
if not ToolFileManager.verify_file(
|
||||||
timestamp=args['timestamp'],
|
file_id=file_id,
|
||||||
nonce=args['nonce'],
|
timestamp=args["timestamp"],
|
||||||
sign=args['sign'],
|
nonce=args["nonce"],
|
||||||
|
sign=args["sign"],
|
||||||
):
|
):
|
||||||
raise Forbidden('Invalid request.')
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = ToolFileManager.get_file_generator_by_tool_file_id(
|
result = ToolFileManager.get_file_generator_by_tool_file_id(
|
||||||
|
@ -32,7 +33,7 @@ class ToolFilePreviewApi(Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
raise NotFound('file is not found')
|
raise NotFound("file is not found")
|
||||||
|
|
||||||
generator, mimetype = result
|
generator, mimetype = result
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -40,9 +41,11 @@ class ToolFilePreviewApi(Resource):
|
||||||
|
|
||||||
return Response(generator, mimetype=mimetype)
|
return Response(generator, mimetype=mimetype)
|
||||||
|
|
||||||
api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>')
|
|
||||||
|
api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
|
@ -2,8 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
|
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
from .workspace import workspace
|
from .workspace import workspace
|
||||||
|
|
||||||
|
|
|
@ -9,29 +9,24 @@ from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseWorkspace(Resource):
|
class EnterpriseWorkspace(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@inner_api_only
|
@inner_api_only
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, location='json')
|
parser.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument('owner_email', type=str, required=True, location='json')
|
parser.add_argument("owner_email", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
account = Account.query.filter_by(email=args['owner_email']).first()
|
account = Account.query.filter_by(email=args["owner_email"]).first()
|
||||||
if account is None:
|
if account is None:
|
||||||
return {
|
return {"message": "owner account not found."}, 404
|
||||||
'message': 'owner account not found.'
|
|
||||||
}, 404
|
|
||||||
|
|
||||||
tenant = TenantService.create_tenant(args['name'])
|
tenant = TenantService.create_tenant(args["name"])
|
||||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||||
|
|
||||||
tenant_was_created.send(tenant)
|
tenant_was_created.send(tenant)
|
||||||
|
|
||||||
return {
|
return {"message": "enterprise workspace created."}
|
||||||
'message': 'enterprise workspace created.'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
|
api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
|
||||||
|
|
|
@ -17,7 +17,7 @@ def inner_api_only(view):
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
# get header 'X-Inner-Api-Key'
|
# get header 'X-Inner-Api-Key'
|
||||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||||
abort(401)
|
abort(401)
|
||||||
|
|
||||||
|
@ -33,29 +33,29 @@ def inner_api_user_auth(view):
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
# get header 'X-Inner-Api-Key'
|
# get header 'X-Inner-Api-Key'
|
||||||
authorization = request.headers.get('Authorization')
|
authorization = request.headers.get("Authorization")
|
||||||
if not authorization:
|
if not authorization:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
parts = authorization.split(':')
|
parts = authorization.split(":")
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
user_id, token = parts
|
user_id, token = parts
|
||||||
if ' ' in user_id:
|
if " " in user_id:
|
||||||
user_id = user_id.split(' ')[1]
|
user_id = user_id.split(" ")[1]
|
||||||
|
|
||||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||||
|
|
||||||
data_to_sign = f'DIFY {user_id}'
|
data_to_sign = f"DIFY {user_id}"
|
||||||
|
|
||||||
signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
|
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
|
||||||
signature = b64encode(signature.digest()).decode('utf-8')
|
signature = b64encode(signature.digest()).decode("utf-8")
|
||||||
|
|
||||||
if signature != token:
|
if signature != token:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
bp = Blueprint('service_api', __name__, url_prefix='/v1')
|
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from flask_restful import Resource, fields, marshal_with
|
from flask_restful import Resource, fields, marshal_with
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
@ -13,32 +12,30 @@ class AppParameterApi(Resource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
||||||
variable_fields = {
|
variable_fields = {
|
||||||
'key': fields.String,
|
"key": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'description': fields.String,
|
"description": fields.String,
|
||||||
'type': fields.String,
|
"type": fields.String,
|
||||||
'default': fields.String,
|
"default": fields.String,
|
||||||
'max_length': fields.Integer,
|
"max_length": fields.Integer,
|
||||||
'options': fields.List(fields.String)
|
"options": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
system_parameters_fields = {
|
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||||
'image_file_size_limit': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
parameters_fields = {
|
parameters_fields = {
|
||||||
'opening_statement': fields.String,
|
"opening_statement": fields.String,
|
||||||
'suggested_questions': fields.Raw,
|
"suggested_questions": fields.Raw,
|
||||||
'suggested_questions_after_answer': fields.Raw,
|
"suggested_questions_after_answer": fields.Raw,
|
||||||
'speech_to_text': fields.Raw,
|
"speech_to_text": fields.Raw,
|
||||||
'text_to_speech': fields.Raw,
|
"text_to_speech": fields.Raw,
|
||||||
'retriever_resource': fields.Raw,
|
"retriever_resource": fields.Raw,
|
||||||
'annotation_reply': fields.Raw,
|
"annotation_reply": fields.Raw,
|
||||||
'more_like_this': fields.Raw,
|
"more_like_this": fields.Raw,
|
||||||
'user_input_form': fields.Raw,
|
"user_input_form": fields.Raw,
|
||||||
'sensitive_word_avoidance': fields.Raw,
|
"sensitive_word_avoidance": fields.Raw,
|
||||||
'file_upload': fields.Raw,
|
"file_upload": fields.Raw,
|
||||||
'system_parameters': fields.Nested(system_parameters_fields)
|
"system_parameters": fields.Nested(system_parameters_fields),
|
||||||
}
|
}
|
||||||
|
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
|
@ -56,30 +53,35 @@ class AppParameterApi(Resource):
|
||||||
app_model_config = app_model.app_model_config
|
app_model_config = app_model.app_model_config
|
||||||
features_dict = app_model_config.to_dict()
|
features_dict = app_model_config.to_dict()
|
||||||
|
|
||||||
user_input_form = features_dict.get('user_input_form', [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'opening_statement': features_dict.get('opening_statement'),
|
"opening_statement": features_dict.get("opening_statement"),
|
||||||
'suggested_questions': features_dict.get('suggested_questions', []),
|
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||||
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
|
"suggested_questions_after_answer": features_dict.get(
|
||||||
{"enabled": False}),
|
"suggested_questions_after_answer", {"enabled": False}
|
||||||
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
|
),
|
||||||
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
|
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||||
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
|
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||||
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
|
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||||
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
|
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||||
'user_input_form': user_input_form,
|
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||||
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
|
"user_input_form": user_input_form,
|
||||||
{"enabled": False, "type": "", "configs": []}),
|
"sensitive_word_avoidance": features_dict.get(
|
||||||
'file_upload': features_dict.get('file_upload', {"image": {
|
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||||
|
),
|
||||||
|
"file_upload": features_dict.get(
|
||||||
|
"file_upload",
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
"enabled": False,
|
"enabled": False,
|
||||||
"number_limits": 3,
|
"number_limits": 3,
|
||||||
"detail": "high",
|
"detail": "high",
|
||||||
"transfer_methods": ["remote_url", "local_file"]
|
"transfer_methods": ["remote_url", "local_file"],
|
||||||
}}),
|
|
||||||
'system_parameters': {
|
|
||||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,16 +91,14 @@ class AppMetaApi(Resource):
|
||||||
"""Get app meta"""
|
"""Get app meta"""
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
class AppInfoApi(Resource):
|
class AppInfoApi(Resource):
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""Get app information"""
|
"""Get app information"""
|
||||||
return {
|
return {"name": app_model.name, "description": app_model.description}
|
||||||
'name':app_model.name,
|
|
||||||
'description':app_model.description
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppParameterApi, '/parameters')
|
api.add_resource(AppParameterApi, "/parameters")
|
||||||
api.add_resource(AppMetaApi, '/meta')
|
api.add_resource(AppMetaApi, "/meta")
|
||||||
api.add_resource(AppInfoApi, '/info')
|
api.add_resource(AppInfoApi, "/info")
|
||||||
|
|
|
@ -33,14 +33,10 @@ from services.errors.audio import (
|
||||||
class AudioApi(Resource):
|
class AudioApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AudioService.transcript_asr(
|
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||||
app_model=app_model,
|
|
||||||
file=file,
|
|
||||||
end_user=end_user
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
@ -74,30 +70,32 @@ class TextApi(Resource):
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument('text', type=str, location='json')
|
parser.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument("streaming", type=bool, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id', None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get('text', None)
|
text = args.get("text", None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (
|
||||||
|
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict
|
||||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
):
|
||||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
|
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
voice = (
|
||||||
|
args.get("voice")
|
||||||
|
if args.get("voice")
|
||||||
|
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model,
|
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||||
message_id=message_id,
|
|
||||||
end_user=end_user.external_user_id,
|
|
||||||
voice=voice,
|
|
||||||
text=text
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -127,5 +125,5 @@ class TextApi(Resource):
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AudioApi, '/audio-to-text')
|
api.add_resource(AudioApi, "/audio-to-text")
|
||||||
api.add_resource(TextApi, '/text-to-audio')
|
api.add_resource(TextApi, "/text-to-audio")
|
||||||
|
|
|
@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService
|
||||||
class CompletionApi(Resource):
|
class CompletionApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, location='json', default='')
|
parser.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
|
@ -84,12 +84,12 @@ class CompletionApi(Resource):
|
||||||
class CompletionStopApi(Resource):
|
class CompletionStopApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ChatApi(Resource):
|
class ChatApi(Resource):
|
||||||
|
@ -100,25 +100,21 @@ class ChatApi(Resource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, required=True, location='json')
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
|
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||||
user=end_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -153,10 +149,10 @@ class ChatStopApi(Resource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CompletionApi, '/completion-messages')
|
api.add_resource(CompletionApi, "/completion-messages")
|
||||||
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
|
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
|
||||||
api.add_resource(ChatApi, '/chat-messages')
|
api.add_resource(ChatApi, "/chat-messages")
|
||||||
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
|
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
|
||||||
|
|
|
@ -14,7 +14,6 @@ from services.conversation_service import ConversationService
|
||||||
|
|
||||||
|
|
||||||
class ConversationApi(Resource):
|
class ConversationApi(Resource):
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model: App, end_user: EndUser):
|
def get(self, app_model: App, end_user: EndUser):
|
||||||
|
@ -23,20 +22,26 @@ class ConversationApi(Resource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
parser.add_argument(
|
||||||
required=False, default='-updated_at', location='args')
|
"sort_by",
|
||||||
|
type=str,
|
||||||
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
|
required=False,
|
||||||
|
default="-updated_at",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.pagination_by_last_id(
|
return ConversationService.pagination_by_last_id(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=end_user,
|
user=end_user,
|
||||||
last_id=args['last_id'],
|
last_id=args["last_id"],
|
||||||
limit=args['limit'],
|
limit=args["limit"],
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
sort_by=args['sort_by']
|
sort_by=args["sort_by"],
|
||||||
)
|
)
|
||||||
except services.errors.conversation.LastConversationNotExistsError:
|
except services.errors.conversation.LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
@ -56,11 +61,10 @@ class ConversationDetailApi(Resource):
|
||||||
ConversationService.delete(app_model, conversation_id, end_user)
|
ConversationService.delete(app_model, conversation_id, end_user)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenameApi(Resource):
|
class ConversationRenameApi(Resource):
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||||
|
@ -71,22 +75,16 @@ class ConversationRenameApi(Resource):
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=False, location='json')
|
parser.add_argument("name", type=str, required=False, location="json")
|
||||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||||
app_model,
|
|
||||||
conversation_id,
|
|
||||||
end_user,
|
|
||||||
args['name'],
|
|
||||||
args['auto_generate']
|
|
||||||
)
|
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
|
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
|
||||||
api.add_resource(ConversationApi, '/conversations')
|
api.add_resource(ConversationApi, "/conversations")
|
||||||
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
|
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
|
||||||
|
|
|
@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class AppUnavailableError(BaseHTTPException):
|
class AppUnavailableError(BaseHTTPException):
|
||||||
error_code = 'app_unavailable'
|
error_code = "app_unavailable"
|
||||||
description = "App unavailable, please check your app configurations."
|
description = "App unavailable, please check your app configurations."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotCompletionAppError(BaseHTTPException):
|
class NotCompletionAppError(BaseHTTPException):
|
||||||
error_code = 'not_completion_app'
|
error_code = "not_completion_app"
|
||||||
description = "Please check if your Completion app mode matches the right API route."
|
description = "Please check if your Completion app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotChatAppError(BaseHTTPException):
|
class NotChatAppError(BaseHTTPException):
|
||||||
error_code = 'not_chat_app'
|
error_code = "not_chat_app"
|
||||||
description = "Please check if your app mode matches the right API route."
|
description = "Please check if your app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotWorkflowAppError(BaseHTTPException):
|
class NotWorkflowAppError(BaseHTTPException):
|
||||||
error_code = 'not_workflow_app'
|
error_code = "not_workflow_app"
|
||||||
description = "Please check if your app mode matches the right API route."
|
description = "Please check if your app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ConversationCompletedError(BaseHTTPException):
|
class ConversationCompletedError(BaseHTTPException):
|
||||||
error_code = 'conversation_completed'
|
error_code = "conversation_completed"
|
||||||
description = "The conversation has ended. Please start a new conversation."
|
description = "The conversation has ended. Please start a new conversation."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotInitializeError(BaseHTTPException):
|
class ProviderNotInitializeError(BaseHTTPException):
|
||||||
error_code = 'provider_not_initialize'
|
error_code = "provider_not_initialize"
|
||||||
description = "No valid model provider credentials found. " \
|
description = (
|
||||||
|
"No valid model provider credentials found. "
|
||||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderQuotaExceededError(BaseHTTPException):
|
class ProviderQuotaExceededError(BaseHTTPException):
|
||||||
error_code = 'provider_quota_exceeded'
|
error_code = "provider_quota_exceeded"
|
||||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
description = (
|
||||||
|
"Your quota for Dify Hosted OpenAI has been exhausted. "
|
||||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
||||||
error_code = 'model_currently_not_support'
|
error_code = "model_currently_not_support"
|
||||||
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestError(BaseHTTPException):
|
class CompletionRequestError(BaseHTTPException):
|
||||||
error_code = 'completion_request_error'
|
error_code = "completion_request_error"
|
||||||
description = "Completion request failed."
|
description = "Completion request failed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NoAudioUploadedError(BaseHTTPException):
|
class NoAudioUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_audio_uploaded'
|
error_code = "no_audio_uploaded"
|
||||||
description = "Please upload your audio."
|
description = "Please upload your audio."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AudioTooLargeError(BaseHTTPException):
|
class AudioTooLargeError(BaseHTTPException):
|
||||||
error_code = 'audio_too_large'
|
error_code = "audio_too_large"
|
||||||
description = "Audio size exceeded. {message}"
|
description = "Audio size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_audio_type'
|
error_code = "unsupported_audio_type"
|
||||||
description = "Audio type not allowed."
|
description = "Audio type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||||
error_code = 'provider_not_support_speech_to_text'
|
error_code = "provider_not_support_speech_to_text"
|
||||||
description = "Provider not support speech to text."
|
description = "Provider not support speech to text."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NoFileUploadedError(BaseHTTPException):
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_file_uploaded'
|
error_code = "no_file_uploaded"
|
||||||
description = "Please upload your file."
|
description = "Please upload your file."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = 'too_many_files'
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class FileTooLargeError(BaseHTTPException):
|
class FileTooLargeError(BaseHTTPException):
|
||||||
error_code = 'file_too_large'
|
error_code = "file_too_large"
|
||||||
description = "File size exceeded. {message}"
|
description = "File size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
|
@ -16,15 +16,13 @@ from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
class FileApi(Resource):
|
class FileApi(Resource):
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||||
@marshal_with(file_fields)
|
@marshal_with(file_fields)
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
|
file = request.files["file"]
|
||||||
file = request.files['file']
|
|
||||||
|
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if not file.mimetype:
|
if not file.mimetype:
|
||||||
|
@ -43,4 +41,4 @@ class FileApi(Resource):
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, '/files/upload')
|
api.add_resource(FileApi, "/files/upload")
|
||||||
|
|
|
@ -17,61 +17,59 @@ from services.message_service import MessageService
|
||||||
|
|
||||||
|
|
||||||
class MessageListApi(Resource):
|
class MessageListApi(Resource):
|
||||||
feedback_fields = {
|
feedback_fields = {"rating": fields.String}
|
||||||
'rating': fields.String
|
|
||||||
}
|
|
||||||
retriever_resource_fields = {
|
retriever_resource_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'message_id': fields.String,
|
"message_id": fields.String,
|
||||||
'position': fields.Integer,
|
"position": fields.Integer,
|
||||||
'dataset_id': fields.String,
|
"dataset_id": fields.String,
|
||||||
'dataset_name': fields.String,
|
"dataset_name": fields.String,
|
||||||
'document_id': fields.String,
|
"document_id": fields.String,
|
||||||
'document_name': fields.String,
|
"document_name": fields.String,
|
||||||
'data_source_type': fields.String,
|
"data_source_type": fields.String,
|
||||||
'segment_id': fields.String,
|
"segment_id": fields.String,
|
||||||
'score': fields.Float,
|
"score": fields.Float,
|
||||||
'hit_count': fields.Integer,
|
"hit_count": fields.Integer,
|
||||||
'word_count': fields.Integer,
|
"word_count": fields.Integer,
|
||||||
'segment_position': fields.Integer,
|
"segment_position": fields.Integer,
|
||||||
'index_node_hash': fields.String,
|
"index_node_hash": fields.String,
|
||||||
'content': fields.String,
|
"content": fields.String,
|
||||||
'created_at': TimestampField
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
agent_thought_fields = {
|
agent_thought_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'chain_id': fields.String,
|
"chain_id": fields.String,
|
||||||
'message_id': fields.String,
|
"message_id": fields.String,
|
||||||
'position': fields.Integer,
|
"position": fields.Integer,
|
||||||
'thought': fields.String,
|
"thought": fields.String,
|
||||||
'tool': fields.String,
|
"tool": fields.String,
|
||||||
'tool_labels': fields.Raw,
|
"tool_labels": fields.Raw,
|
||||||
'tool_input': fields.String,
|
"tool_input": fields.String,
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'observation': fields.String,
|
"observation": fields.String,
|
||||||
'message_files': fields.List(fields.String, attribute='files')
|
"message_files": fields.List(fields.String, attribute="files"),
|
||||||
}
|
}
|
||||||
|
|
||||||
message_fields = {
|
message_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'conversation_id': fields.String,
|
"conversation_id": fields.String,
|
||||||
'inputs': fields.Raw,
|
"inputs": fields.Raw,
|
||||||
'query': fields.String,
|
"query": fields.String,
|
||||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
|
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||||
'status': fields.String,
|
"status": fields.String,
|
||||||
'error': fields.String,
|
"error": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
message_infinite_scroll_pagination_fields = {
|
message_infinite_scroll_pagination_fields = {
|
||||||
'limit': fields.Integer,
|
"limit": fields.Integer,
|
||||||
'has_more': fields.Boolean,
|
"has_more": fields.Boolean,
|
||||||
'data': fields.List(fields.Nested(message_fields))
|
"data": fields.List(fields.Nested(message_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||||
|
@ -82,14 +80,15 @@ class MessageListApi(Resource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return MessageService.pagination_by_first_id(app_model, end_user,
|
return MessageService.pagination_by_first_id(
|
||||||
args['conversation_id'], args['first_id'], args['limit'])
|
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
|
)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
except services.errors.message.FirstMessageNotExistsError:
|
except services.errors.message.FirstMessageNotExistsError:
|
||||||
|
@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
|
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
|
||||||
except services.errors.message.MessageNotExistsError:
|
except services.errors.message.MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class MessageSuggestedApi(Resource):
|
class MessageSuggestedApi(Resource):
|
||||||
|
@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
|
||||||
user=end_user,
|
|
||||||
message_id=message_id,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API
|
|
||||||
)
|
)
|
||||||
except services.errors.message.MessageNotExistsError:
|
except services.errors.message.MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
return {'result': 'success', 'data': questions}
|
return {"result": "success", "data": questions}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(MessageListApi, '/messages')
|
api.add_resource(MessageListApi, "/messages")
|
||||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
|
||||||
api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
|
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
|
||||||
|
|
|
@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
workflow_run_fields = {
|
workflow_run_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'workflow_id': fields.String,
|
"workflow_id": fields.String,
|
||||||
'status': fields.String,
|
"status": fields.String,
|
||||||
'inputs': fields.Raw,
|
"inputs": fields.Raw,
|
||||||
'outputs': fields.Raw,
|
"outputs": fields.Raw,
|
||||||
'error': fields.String,
|
"error": fields.String,
|
||||||
'total_steps': fields.Integer,
|
"total_steps": fields.Integer,
|
||||||
'total_tokens': fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
'created_at': fields.DateTime,
|
"created_at": fields.DateTime,
|
||||||
'finished_at': fields.DateTime,
|
"finished_at": fields.DateTime,
|
||||||
'elapsed_time': fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunDetailApi(Resource):
|
class WorkflowRunDetailApi(Resource):
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
@marshal_with(workflow_run_fields)
|
@marshal_with(workflow_run_fields)
|
||||||
|
@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource):
|
||||||
|
|
||||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
|
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunApi(Resource):
|
class WorkflowRunApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
|
@ -67,20 +70,16 @@ class WorkflowRunApi(Resource):
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args.get('response_mode') == 'streaming'
|
streaming = args.get("response_mode") == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||||
user=end_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -111,11 +110,9 @@ class WorkflowTaskStopApi(Resource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
||||||
return {
|
return {"result": "success"}
|
||||||
"result": "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(WorkflowRunApi, '/workflows/run')
|
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||||
api.add_resource(WorkflowRunDetailApi, '/workflows/run/<string:workflow_id>')
|
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
|
||||||
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
|
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||||
|
|
|
@ -16,7 +16,7 @@ from services.dataset_service import DatasetService
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
raise ValueError('Name must be between 1 to 40 characters.')
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource):
|
||||||
def get(self, tenant_id):
|
def get(self, tenant_id):
|
||||||
"""Resource for getting datasets."""
|
"""Resource for getting datasets."""
|
||||||
|
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
provider = request.args.get('provider', default="vendor")
|
provider = request.args.get("provider", default="vendor")
|
||||||
search = request.args.get('keyword', default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
tag_ids = request.args.getlist('tag_ids')
|
tag_ids = request.args.getlist("tag_ids")
|
||||||
|
|
||||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
|
||||||
tenant_id, current_user, search, tag_ids)
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(
|
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding_models = configurations.get_models(
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
only_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
model_names = []
|
model_names = []
|
||||||
for embedding_model in embedding_models:
|
for embedding_model in embedding_models:
|
||||||
|
@ -51,50 +45,59 @@ class DatasetListApi(DatasetApiResource):
|
||||||
|
|
||||||
data = marshal(datasets, dataset_detail_fields)
|
data = marshal(datasets, dataset_detail_fields)
|
||||||
for item in data:
|
for item in data:
|
||||||
if item['indexing_technique'] == 'high_quality':
|
if item["indexing_technique"] == "high_quality":
|
||||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||||
if item_model in model_names:
|
if item_model in model_names:
|
||||||
item['embedding_available'] = True
|
item["embedding_available"] = True
|
||||||
else:
|
else:
|
||||||
item['embedding_available'] = False
|
item["embedding_available"] = False
|
||||||
else:
|
else:
|
||||||
item['embedding_available'] = True
|
item["embedding_available"] = True
|
||||||
response = {
|
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||||
'data': data,
|
|
||||||
'has_more': len(datasets) == limit,
|
|
||||||
'limit': limit,
|
|
||||||
'total': total,
|
|
||||||
'page': page
|
|
||||||
}
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
def post(self, tenant_id):
|
def post(self, tenant_id):
|
||||||
"""Resource for creating datasets."""
|
"""Resource for creating datasets."""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', nullable=False, required=True,
|
parser.add_argument(
|
||||||
help='type is required. Name must be between 1 to 40 characters.',
|
"name",
|
||||||
type=_validate_name)
|
nullable=False,
|
||||||
parser.add_argument('indexing_technique', type=str, location='json',
|
required=True,
|
||||||
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"indexing_technique",
|
||||||
|
type=str,
|
||||||
|
location="json",
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
help='Invalid indexing technique.')
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument('permission', type=str, location='json', choices=(
|
)
|
||||||
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False)
|
parser.add_argument(
|
||||||
|
"permission",
|
||||||
|
type=str,
|
||||||
|
location="json",
|
||||||
|
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||||
|
help="Invalid permission.",
|
||||||
|
required=False,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=args['name'],
|
name=args["name"],
|
||||||
indexing_technique=args['indexing_technique'],
|
indexing_technique=args["indexing_technique"],
|
||||||
account=current_user,
|
account=current_user,
|
||||||
permission=args['permission']
|
permission=args["permission"],
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
|
|
||||||
return marshal(dataset, dataset_detail_fields), 200
|
return marshal(dataset, dataset_detail_fields), 200
|
||||||
|
|
||||||
|
|
||||||
class DatasetApi(DatasetApiResource):
|
class DatasetApi(DatasetApiResource):
|
||||||
"""Resource for dataset."""
|
"""Resource for dataset."""
|
||||||
|
|
||||||
|
@ -118,11 +121,12 @@ class DatasetApi(DatasetApiResource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||||
return {'result': 'success'}, 204
|
return {"result": "success"}, 204
|
||||||
else:
|
else:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
except services.errors.dataset.DatasetInUseError:
|
except services.errors.dataset.DatasetInUseError:
|
||||||
raise DatasetInUseError()
|
raise DatasetInUseError()
|
||||||
|
|
||||||
api.add_resource(DatasetListApi, '/datasets')
|
|
||||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
api.add_resource(DatasetListApi, "/datasets")
|
||||||
|
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||||
|
|
|
@ -27,47 +27,40 @@ from services.file_service import FileService
|
||||||
class DocumentAddByTextApi(DatasetApiResource):
|
class DocumentAddByTextApi(DatasetApiResource):
|
||||||
"""Resource for documents."""
|
"""Resource for documents."""
|
||||||
|
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
@cloud_edition_billing_resource_check('documents', 'dataset')
|
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||||
def post(self, tenant_id, dataset_id):
|
def post(self, tenant_id, dataset_id):
|
||||||
"""Create document by text."""
|
"""Create document by text."""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
parser.add_argument("text", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
parser.add_argument(
|
||||||
location='json')
|
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||||
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
)
|
||||||
location='json')
|
parser.add_argument(
|
||||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||||
location='json')
|
)
|
||||||
|
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset is not exist.')
|
raise ValueError("Dataset is not exist.")
|
||||||
|
|
||||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||||
raise ValueError('indexing_technique is required.')
|
raise ValueError("indexing_technique is required.")
|
||||||
|
|
||||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||||
data_source = {
|
data_source = {
|
||||||
'type': 'upload_file',
|
"type": "upload_file",
|
||||||
'info_list': {
|
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||||
'data_source_type': 'upload_file',
|
|
||||||
'file_info_list': {
|
|
||||||
'file_ids': [upload_file.id]
|
|
||||||
}
|
}
|
||||||
}
|
args["data_source"] = data_source
|
||||||
}
|
|
||||||
args['data_source'] = data_source
|
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.document_create_args_validate(args)
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
|
@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
document_data=args,
|
document_data=args,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||||
created_from='api'
|
created_from="api",
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
document = documents[0]
|
document = documents[0]
|
||||||
|
|
||||||
documents_and_batch_fields = {
|
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||||
'document': marshal(document, document_fields),
|
|
||||||
'batch': batch
|
|
||||||
}
|
|
||||||
return documents_and_batch_fields, 200
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||||
"""Resource for update documents."""
|
"""Resource for update documents."""
|
||||||
|
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
def post(self, tenant_id, dataset_id, document_id):
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
"""Update document by text."""
|
"""Update document by text."""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
|
parser.add_argument("text", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
|
||||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
parser.add_argument(
|
||||||
location='json')
|
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
)
|
||||||
location='json')
|
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset is not exist.')
|
raise ValueError("Dataset is not exist.")
|
||||||
|
|
||||||
if args['text']:
|
if args["text"]:
|
||||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||||
data_source = {
|
data_source = {
|
||||||
'type': 'upload_file',
|
"type": "upload_file",
|
||||||
'info_list': {
|
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||||
'data_source_type': 'upload_file',
|
|
||||||
'file_info_list': {
|
|
||||||
'file_ids': [upload_file.id]
|
|
||||||
}
|
}
|
||||||
}
|
args["data_source"] = data_source
|
||||||
}
|
|
||||||
args['data_source'] = data_source
|
|
||||||
# validate args
|
# validate args
|
||||||
args['original_document_id'] = str(document_id)
|
args["original_document_id"] = str(document_id)
|
||||||
DocumentService.document_create_args_validate(args)
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
document_data=args,
|
document_data=args,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||||
created_from='api'
|
created_from="api",
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
document = documents[0]
|
document = documents[0]
|
||||||
|
|
||||||
documents_and_batch_fields = {
|
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||||
'document': marshal(document, document_fields),
|
|
||||||
'batch': batch
|
|
||||||
}
|
|
||||||
return documents_and_batch_fields, 200
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
class DocumentAddByFileApi(DatasetApiResource):
|
class DocumentAddByFileApi(DatasetApiResource):
|
||||||
"""Resource for documents."""
|
"""Resource for documents."""
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
|
||||||
@cloud_edition_billing_resource_check('documents', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
|
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||||
def post(self, tenant_id, dataset_id):
|
def post(self, tenant_id, dataset_id):
|
||||||
"""Create document by upload file."""
|
"""Create document by upload file."""
|
||||||
args = {}
|
args = {}
|
||||||
if 'data' in request.form:
|
if "data" in request.form:
|
||||||
args = json.loads(request.form['data'])
|
args = json.loads(request.form["data"])
|
||||||
if 'doc_form' not in args:
|
if "doc_form" not in args:
|
||||||
args['doc_form'] = 'text_model'
|
args["doc_form"] = "text_model"
|
||||||
if 'doc_language' not in args:
|
if "doc_language" not in args:
|
||||||
args['doc_language'] = 'English'
|
args["doc_language"] = "English"
|
||||||
# get dataset info
|
# get dataset info
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset is not exist.')
|
raise ValueError("Dataset is not exist.")
|
||||||
if not dataset.indexing_technique and not args.get('indexing_technique'):
|
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||||
raise ValueError('indexing_technique is required.')
|
raise ValueError("indexing_technique is required.")
|
||||||
|
|
||||||
# save file info
|
# save file info
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
|
|
||||||
upload_file = FileService.upload_file(file, current_user)
|
upload_file = FileService.upload_file(file, current_user)
|
||||||
data_source = {
|
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||||
'type': 'upload_file',
|
args["data_source"] = data_source
|
||||||
'info_list': {
|
|
||||||
'file_info_list': {
|
|
||||||
'file_ids': [upload_file.id]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args['data_source'] = data_source
|
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.document_create_args_validate(args)
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
|
@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
document_data=args,
|
document_data=args,
|
||||||
account=dataset.created_by_account,
|
account=dataset.created_by_account,
|
||||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||||
created_from='api'
|
created_from="api",
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
document = documents[0]
|
document = documents[0]
|
||||||
documents_and_batch_fields = {
|
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||||
'document': marshal(document, document_fields),
|
|
||||||
'batch': batch
|
|
||||||
}
|
|
||||||
return documents_and_batch_fields, 200
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||||
"""Resource for update documents."""
|
"""Resource for update documents."""
|
||||||
|
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
def post(self, tenant_id, dataset_id, document_id):
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
"""Update document by upload file."""
|
"""Update document by upload file."""
|
||||||
args = {}
|
args = {}
|
||||||
if 'data' in request.form:
|
if "data" in request.form:
|
||||||
args = json.loads(request.form['data'])
|
args = json.loads(request.form["data"])
|
||||||
if 'doc_form' not in args:
|
if "doc_form" not in args:
|
||||||
args['doc_form'] = 'text_model'
|
args["doc_form"] = "text_model"
|
||||||
if 'doc_language' not in args:
|
if "doc_language" not in args:
|
||||||
args['doc_language'] = 'English'
|
args["doc_language"] = "English"
|
||||||
|
|
||||||
# get dataset info
|
# get dataset info
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset is not exist.')
|
raise ValueError("Dataset is not exist.")
|
||||||
if 'file' in request.files:
|
if "file" in request.files:
|
||||||
# save file info
|
# save file info
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
|
|
||||||
upload_file = FileService.upload_file(file, current_user)
|
upload_file = FileService.upload_file(file, current_user)
|
||||||
data_source = {
|
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||||
'type': 'upload_file',
|
args["data_source"] = data_source
|
||||||
'info_list': {
|
|
||||||
'file_info_list': {
|
|
||||||
'file_ids': [upload_file.id]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args['data_source'] = data_source
|
|
||||||
# validate args
|
# validate args
|
||||||
args['original_document_id'] = str(document_id)
|
args["original_document_id"] = str(document_id)
|
||||||
DocumentService.document_create_args_validate(args)
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
document_data=args,
|
document_data=args,
|
||||||
account=dataset.created_by_account,
|
account=dataset.created_by_account,
|
||||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||||
created_from='api'
|
created_from="api",
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
document = documents[0]
|
document = documents[0]
|
||||||
documents_and_batch_fields = {
|
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||||
'document': marshal(document, document_fields),
|
|
||||||
'batch': batch
|
|
||||||
}
|
|
||||||
return documents_and_batch_fields, 200
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
|
@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource):
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
|
|
||||||
# get dataset info
|
# get dataset info
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError('Dataset is not exist.')
|
raise ValueError("Dataset is not exist.")
|
||||||
|
|
||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
|
|
||||||
|
@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource):
|
||||||
# delete document
|
# delete document
|
||||||
DocumentService.delete_document(document)
|
DocumentService.delete_document(document)
|
||||||
except services.errors.document.DocumentIndexingError:
|
except services.errors.document.DocumentIndexingError:
|
||||||
raise DocumentIndexingError('Cannot delete document during indexing.')
|
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class DocumentListApi(DatasetApiResource):
|
class DocumentListApi(DatasetApiResource):
|
||||||
def get(self, tenant_id, dataset_id):
|
def get(self, tenant_id, dataset_id):
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
page = request.args.get('page', default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get('limit', default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
search = request.args.get('keyword', default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
query = Document.query.filter_by(
|
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||||
dataset_id=str(dataset_id), tenant_id=tenant_id)
|
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search = f'%{search}%'
|
search = f"%{search}%"
|
||||||
query = query.filter(Document.name.like(search))
|
query = query.filter(Document.name.like(search))
|
||||||
|
|
||||||
query = query.order_by(desc(Document.created_at))
|
query = query.order_by(desc(Document.created_at))
|
||||||
|
|
||||||
paginated_documents = query.paginate(
|
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||||
page=page, per_page=limit, max_per_page=100, error_out=False)
|
|
||||||
documents = paginated_documents.items
|
documents = paginated_documents.items
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
'data': marshal(documents, document_fields),
|
"data": marshal(documents, document_fields),
|
||||||
'has_more': len(documents) == limit,
|
"has_more": len(documents) == limit,
|
||||||
'limit': limit,
|
"limit": limit,
|
||||||
'total': paginated_documents.total,
|
"total": paginated_documents.total,
|
||||||
'page': page
|
"page": page,
|
||||||
}
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||||
batch = str(batch)
|
batch = str(batch)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
# get dataset
|
# get dataset
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# get documents
|
# get documents
|
||||||
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||||
if not documents:
|
if not documents:
|
||||||
raise NotFound('Documents not found.')
|
raise NotFound("Documents not found.")
|
||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
completed_segments = DocumentSegment.query.filter(
|
||||||
|
DocumentSegment.completed_at.isnot(None),
|
||||||
DocumentSegment.document_id == str(document.id),
|
DocumentSegment.document_id == str(document.id),
|
||||||
DocumentSegment.status != 're_segment').count()
|
DocumentSegment.status != "re_segment",
|
||||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
).count()
|
||||||
DocumentSegment.status != 're_segment').count()
|
total_segments = DocumentSegment.query.filter(
|
||||||
|
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
|
||||||
|
).count()
|
||||||
document.completed_segments = completed_segments
|
document.completed_segments = completed_segments
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
if document.is_paused:
|
if document.is_paused:
|
||||||
document.indexing_status = 'paused'
|
document.indexing_status = "paused"
|
||||||
documents_status.append(marshal(document, document_status_fields))
|
documents_status.append(marshal(document, document_status_fields))
|
||||||
data = {
|
data = {"data": documents_status}
|
||||||
'data': documents_status
|
|
||||||
}
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
|
api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||||
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
|
api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file")
|
||||||
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
|
api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||||
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
|
api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file")
|
||||||
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
|
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||||
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
|
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||||
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
|
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||||
|
|
|
@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class NoFileUploadedError(BaseHTTPException):
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_file_uploaded'
|
error_code = "no_file_uploaded"
|
||||||
description = "Please upload your file."
|
description = "Please upload your file."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = 'too_many_files'
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class FileTooLargeError(BaseHTTPException):
|
class FileTooLargeError(BaseHTTPException):
|
||||||
error_code = 'file_too_large'
|
error_code = "file_too_large"
|
||||||
description = "File size exceeded. {message}"
|
description = "File size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||||
error_code = 'high_quality_dataset_only'
|
error_code = "high_quality_dataset_only"
|
||||||
description = "Current operation only supports 'high-quality' datasets."
|
description = "Current operation only supports 'high-quality' datasets."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DatasetNotInitializedError(BaseHTTPException):
|
class DatasetNotInitializedError(BaseHTTPException):
|
||||||
error_code = 'dataset_not_initialized'
|
error_code = "dataset_not_initialized"
|
||||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||||
error_code = 'archived_document_immutable'
|
error_code = "archived_document_immutable"
|
||||||
description = "The archived document is not editable."
|
description = "The archived document is not editable."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class DatasetNameDuplicateError(BaseHTTPException):
|
class DatasetNameDuplicateError(BaseHTTPException):
|
||||||
error_code = 'dataset_name_duplicate'
|
error_code = "dataset_name_duplicate"
|
||||||
description = "The dataset name already exists. Please modify your dataset name."
|
description = "The dataset name already exists. Please modify your dataset name."
|
||||||
code = 409
|
code = 409
|
||||||
|
|
||||||
|
|
||||||
class InvalidActionError(BaseHTTPException):
|
class InvalidActionError(BaseHTTPException):
|
||||||
error_code = 'invalid_action'
|
error_code = "invalid_action"
|
||||||
description = "Invalid action."
|
description = "Invalid action."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||||
error_code = 'document_already_finished'
|
error_code = "document_already_finished"
|
||||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingError(BaseHTTPException):
|
class DocumentIndexingError(BaseHTTPException):
|
||||||
error_code = 'document_indexing'
|
error_code = "document_indexing"
|
||||||
description = "The document is being processed and cannot be edited."
|
description = "The document is being processed and cannot be edited."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class InvalidMetadataError(BaseHTTPException):
|
class InvalidMetadataError(BaseHTTPException):
|
||||||
error_code = 'invalid_metadata'
|
error_code = "invalid_metadata"
|
||||||
description = "The metadata content is incorrect. Please check and verify."
|
description = "The metadata content is incorrect. Please check and verify."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DatasetInUseError(BaseHTTPException):
|
class DatasetInUseError(BaseHTTPException):
|
||||||
error_code = 'dataset_in_use'
|
error_code = "dataset_in_use"
|
||||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||||
code = 409
|
code = 409
|
||||||
|
|
|
@ -21,52 +21,47 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
|
||||||
class SegmentApi(DatasetApiResource):
|
class SegmentApi(DatasetApiResource):
|
||||||
"""Resource for segments."""
|
"""Resource for segments."""
|
||||||
|
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
@cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset')
|
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||||
def post(self, tenant_id, dataset_id, document_id):
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
"""Create single segment."""
|
"""Create single segment."""
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args['segments'] is not None:
|
if args["segments"] is not None:
|
||||||
for args_item in args['segments']:
|
for args_item in args["segments"]:
|
||||||
SegmentService.segment_create_args_validate(args_item, document)
|
SegmentService.segment_create_args_validate(args_item, document)
|
||||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||||
return {
|
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
'data': marshal(segments, segment_fields),
|
|
||||||
'doc_form': document.doc_form
|
|
||||||
}, 200
|
|
||||||
else:
|
else:
|
||||||
return {"error": "Segemtns is required"}, 400
|
return {"error": "Segemtns is required"}, 400
|
||||||
|
|
||||||
|
@ -75,61 +70,53 @@ class SegmentApi(DatasetApiResource):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('status', type=str,
|
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||||
action='append', default=[], location='args')
|
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
status_list = args['status']
|
status_list = args["status"]
|
||||||
keyword = args['keyword']
|
keyword = args["keyword"]
|
||||||
|
|
||||||
query = DocumentSegment.query.filter(
|
query = DocumentSegment.query.filter(
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if status_list:
|
if status_list:
|
||||||
query = query.filter(DocumentSegment.status.in_(status_list))
|
query = query.filter(DocumentSegment.status.in_(status_list))
|
||||||
|
|
||||||
if keyword:
|
if keyword:
|
||||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||||
|
|
||||||
total = query.count()
|
total = query.count()
|
||||||
segments = query.order_by(DocumentSegment.position).all()
|
segments = query.order_by(DocumentSegment.position).all()
|
||||||
return {
|
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200
|
||||||
'data': marshal(segments, segment_fields),
|
|
||||||
'doc_form': document.doc_form,
|
|
||||||
'total': total
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetSegmentApi(DatasetApiResource):
|
class DatasetSegmentApi(DatasetApiResource):
|
||||||
|
@ -137,48 +124,41 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check user's model setting
|
# check user's model setting
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
# check segment
|
# check segment
|
||||||
segment = DocumentSegment.query.filter(
|
segment = DocumentSegment.query.filter(
|
||||||
DocumentSegment.id == str(segment_id),
|
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
).first()
|
).first()
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound('Segment not found.')
|
raise NotFound("Segment not found.")
|
||||||
SegmentService.delete_segment(segment, document, dataset)
|
SegmentService.delete_segment(segment, document, dataset)
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
tenant_id = str(tenant_id)
|
tenant_id = str(tenant_id)
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound('Dataset not found.')
|
raise NotFound("Dataset not found.")
|
||||||
# check user's model setting
|
# check user's model setting
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
# check document
|
# check document
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = DocumentService.get_document(dataset_id, document_id)
|
document = DocumentService.get_document(dataset_id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound('Document not found.')
|
raise NotFound("Document not found.")
|
||||||
if dataset.indexing_technique == 'high_quality':
|
if dataset.indexing_technique == "high_quality":
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
|
@ -186,35 +166,34 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model
|
model=dataset.embedding_model,
|
||||||
)
|
)
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
"in the Settings -> Model Provider.")
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
# check segment
|
# check segment
|
||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = DocumentSegment.query.filter(
|
segment = DocumentSegment.query.filter(
|
||||||
DocumentSegment.id == str(segment_id),
|
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
|
||||||
).first()
|
).first()
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound('Segment not found.')
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('segment', type=dict, required=False, nullable=True, location='json')
|
parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
SegmentService.segment_create_args_validate(args['segment'], document)
|
SegmentService.segment_create_args_validate(args["segment"], document)
|
||||||
segment = SegmentService.update_segment(args['segment'], segment, document, dataset)
|
segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
|
||||||
return {
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
'data': marshal(segment, segment_fields),
|
|
||||||
'doc_form': document.doc_form
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||||
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
api.add_resource(
|
||||||
|
DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>"
|
||||||
|
)
|
||||||
|
|
|
@ -13,4 +13,4 @@ class IndexApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(IndexApi, '/')
|
api.add_resource(IndexApi, "/")
|
||||||
|
|
|
@ -21,9 +21,10 @@ class WhereisUserArg(Enum):
|
||||||
"""
|
"""
|
||||||
Enum for whereis_user_arg.
|
Enum for whereis_user_arg.
|
||||||
"""
|
"""
|
||||||
QUERY = 'query'
|
|
||||||
JSON = 'json'
|
QUERY = "query"
|
||||||
FORM = 'form'
|
JSON = "json"
|
||||||
|
FORM = "form"
|
||||||
|
|
||||||
|
|
||||||
class FetchUserArg(BaseModel):
|
class FetchUserArg(BaseModel):
|
||||||
|
@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||||
def decorator(view_func):
|
def decorator(view_func):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args, **kwargs):
|
def decorated_view(*args, **kwargs):
|
||||||
api_token = validate_and_get_api_token('app')
|
api_token = validate_and_get_api_token("app")
|
||||||
|
|
||||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||||
if not app_model:
|
if not app_model:
|
||||||
raise Forbidden("The app no longer exists.")
|
raise Forbidden("The app no longer exists.")
|
||||||
|
|
||||||
if app_model.status != 'normal':
|
if app_model.status != "normal":
|
||||||
raise Forbidden("The app's status is abnormal.")
|
raise Forbidden("The app's status is abnormal.")
|
||||||
|
|
||||||
if not app_model.enable_api:
|
if not app_model.enable_api:
|
||||||
|
@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||||
if tenant.status == TenantStatus.ARCHIVE:
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
raise Forbidden("The workspace's status is archived.")
|
raise Forbidden("The workspace's status is archived.")
|
||||||
|
|
||||||
kwargs['app_model'] = app_model
|
kwargs["app_model"] = app_model
|
||||||
|
|
||||||
if fetch_user_arg:
|
if fetch_user_arg:
|
||||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||||
user_id = request.args.get('user')
|
user_id = request.args.get("user")
|
||||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||||
user_id = request.get_json().get('user')
|
user_id = request.get_json().get("user")
|
||||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||||
user_id = request.form.get('user')
|
user_id = request.form.get("user")
|
||||||
else:
|
else:
|
||||||
# use default-user
|
# use default-user
|
||||||
user_id = None
|
user_id = None
|
||||||
|
@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||||
if user_id:
|
if user_id:
|
||||||
user_id = str(user_id)
|
user_id = str(user_id)
|
||||||
|
|
||||||
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
|
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||||
|
|
||||||
return view_func(*args, **kwargs)
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
return decorated_view
|
return decorated_view
|
||||||
|
|
||||||
if view is None:
|
if view is None:
|
||||||
|
@ -81,9 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str,
|
def cloud_edition_billing_resource_check(
|
||||||
api_token_type: str,
|
resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
|
||||||
error_msg: str = "You have reached the limit of your subscription."):
|
):
|
||||||
def interceptor(view):
|
def interceptor(view):
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
|
@ -95,33 +97,37 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||||
vector_space = features.vector_space
|
vector_space = features.vector_space
|
||||||
documents_upload_quota = features.documents_upload_quota
|
documents_upload_quota = features.documents_upload_quota
|
||||||
|
|
||||||
if resource == 'members' and 0 < members.limit <= members.size:
|
if resource == "members" and 0 < members.limit <= members.size:
|
||||||
raise Forbidden(error_msg)
|
raise Forbidden(error_msg)
|
||||||
elif resource == 'apps' and 0 < apps.limit <= apps.size:
|
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||||
raise Forbidden(error_msg)
|
raise Forbidden(error_msg)
|
||||||
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
|
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||||
raise Forbidden(error_msg)
|
raise Forbidden(error_msg)
|
||||||
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||||
raise Forbidden(error_msg)
|
raise Forbidden(error_msg)
|
||||||
else:
|
else:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str,
|
def cloud_edition_billing_knowledge_limit_check(
|
||||||
|
resource: str,
|
||||||
api_token_type: str,
|
api_token_type: str,
|
||||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
|
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||||
|
):
|
||||||
def interceptor(view):
|
def interceptor(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
features = FeatureService.get_features(api_token.tenant_id)
|
features = FeatureService.get_features(api_token.tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == 'add_segment':
|
if resource == "add_segment":
|
||||||
if features.billing.subscription.plan == 'sandbox':
|
if features.billing.subscription.plan == "sandbox":
|
||||||
raise Forbidden(error_msg)
|
raise Forbidden(error_msg)
|
||||||
else:
|
else:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -132,17 +138,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str,
|
||||||
|
|
||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def validate_dataset_token(view=None):
|
def validate_dataset_token(view=None):
|
||||||
def decorator(view):
|
def decorator(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
api_token = validate_and_get_api_token('dataset')
|
api_token = validate_and_get_api_token("dataset")
|
||||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
tenant_account_join = (
|
||||||
.filter(Tenant.id == api_token.tenant_id) \
|
db.session.query(Tenant, TenantAccountJoin)
|
||||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
.filter(Tenant.id == api_token.tenant_id)
|
||||||
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
.filter(TenantAccountJoin.tenant_id == Tenant.id)
|
||||||
.filter(Tenant.status == TenantStatus.NORMAL) \
|
.filter(TenantAccountJoin.role.in_(["owner"]))
|
||||||
.one_or_none() # TODO: only owner information is required, so only one is returned.
|
.filter(Tenant.status == TenantStatus.NORMAL)
|
||||||
|
.one_or_none()
|
||||||
|
) # TODO: only owner information is required, so only one is returned.
|
||||||
if tenant_account_join:
|
if tenant_account_join:
|
||||||
tenant, ta = tenant_account_join
|
tenant, ta = tenant_account_join
|
||||||
account = Account.query.filter_by(id=ta.account_id).first()
|
account = Account.query.filter_by(id=ta.account_id).first()
|
||||||
|
@ -156,6 +165,7 @@ def validate_dataset_token(view=None):
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("Tenant does not exist.")
|
raise Unauthorized("Tenant does not exist.")
|
||||||
return view(api_token.tenant_id, *args, **kwargs)
|
return view(api_token.tenant_id, *args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
if view:
|
if view:
|
||||||
|
@ -170,20 +180,24 @@ def validate_and_get_api_token(scope=None):
|
||||||
"""
|
"""
|
||||||
Validate and get API token.
|
Validate and get API token.
|
||||||
"""
|
"""
|
||||||
auth_header = request.headers.get('Authorization')
|
auth_header = request.headers.get("Authorization")
|
||||||
if auth_header is None or ' ' not in auth_header:
|
if auth_header is None or " " not in auth_header:
|
||||||
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
|
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
|
||||||
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
auth_scheme = auth_scheme.lower()
|
auth_scheme = auth_scheme.lower()
|
||||||
|
|
||||||
if auth_scheme != 'bearer':
|
if auth_scheme != "bearer":
|
||||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||||
|
|
||||||
api_token = db.session.query(ApiToken).filter(
|
api_token = (
|
||||||
|
db.session.query(ApiToken)
|
||||||
|
.filter(
|
||||||
ApiToken.token == auth_token,
|
ApiToken.token == auth_token,
|
||||||
ApiToken.type == scope,
|
ApiToken.type == scope,
|
||||||
).first()
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not api_token:
|
if not api_token:
|
||||||
raise Unauthorized("Access token is invalid")
|
raise Unauthorized("Access token is invalid")
|
||||||
|
@ -199,23 +213,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
||||||
Create or update session terminal based on user ID.
|
Create or update session terminal based on user ID.
|
||||||
"""
|
"""
|
||||||
if not user_id:
|
if not user_id:
|
||||||
user_id = 'DEFAULT-USER'
|
user_id = "DEFAULT-USER"
|
||||||
|
|
||||||
end_user = db.session.query(EndUser) \
|
end_user = (
|
||||||
|
db.session.query(EndUser)
|
||||||
.filter(
|
.filter(
|
||||||
EndUser.tenant_id == app_model.tenant_id,
|
EndUser.tenant_id == app_model.tenant_id,
|
||||||
EndUser.app_id == app_model.id,
|
EndUser.app_id == app_model.id,
|
||||||
EndUser.session_id == user_id,
|
EndUser.session_id == user_id,
|
||||||
EndUser.type == 'service_api'
|
EndUser.type == "service_api",
|
||||||
).first()
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if end_user is None:
|
if end_user is None:
|
||||||
end_user = EndUser(
|
end_user = EndUser(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
type='service_api',
|
type="service_api",
|
||||||
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
|
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||||
session_id=user_id
|
session_id=user_id,
|
||||||
)
|
)
|
||||||
db.session.add(end_user)
|
db.session.add(end_user)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
bp = Blueprint('web', __name__, url_prefix='/api')
|
bp = Blueprint("web", __name__, url_prefix="/api")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,33 +10,32 @@ from services.app_service import AppService
|
||||||
|
|
||||||
class AppParameterApi(WebApiResource):
|
class AppParameterApi(WebApiResource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
||||||
variable_fields = {
|
variable_fields = {
|
||||||
'key': fields.String,
|
"key": fields.String,
|
||||||
'name': fields.String,
|
"name": fields.String,
|
||||||
'description': fields.String,
|
"description": fields.String,
|
||||||
'type': fields.String,
|
"type": fields.String,
|
||||||
'default': fields.String,
|
"default": fields.String,
|
||||||
'max_length': fields.Integer,
|
"max_length": fields.Integer,
|
||||||
'options': fields.List(fields.String)
|
"options": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
system_parameters_fields = {
|
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||||
'image_file_size_limit': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
parameters_fields = {
|
parameters_fields = {
|
||||||
'opening_statement': fields.String,
|
"opening_statement": fields.String,
|
||||||
'suggested_questions': fields.Raw,
|
"suggested_questions": fields.Raw,
|
||||||
'suggested_questions_after_answer': fields.Raw,
|
"suggested_questions_after_answer": fields.Raw,
|
||||||
'speech_to_text': fields.Raw,
|
"speech_to_text": fields.Raw,
|
||||||
'text_to_speech': fields.Raw,
|
"text_to_speech": fields.Raw,
|
||||||
'retriever_resource': fields.Raw,
|
"retriever_resource": fields.Raw,
|
||||||
'annotation_reply': fields.Raw,
|
"annotation_reply": fields.Raw,
|
||||||
'more_like_this': fields.Raw,
|
"more_like_this": fields.Raw,
|
||||||
'user_input_form': fields.Raw,
|
"user_input_form": fields.Raw,
|
||||||
'sensitive_word_avoidance': fields.Raw,
|
"sensitive_word_avoidance": fields.Raw,
|
||||||
'file_upload': fields.Raw,
|
"file_upload": fields.Raw,
|
||||||
'system_parameters': fields.Nested(system_parameters_fields)
|
"system_parameters": fields.Nested(system_parameters_fields),
|
||||||
}
|
}
|
||||||
|
|
||||||
@marshal_with(parameters_fields)
|
@marshal_with(parameters_fields)
|
||||||
|
@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource):
|
||||||
app_model_config = app_model.app_model_config
|
app_model_config = app_model.app_model_config
|
||||||
features_dict = app_model_config.to_dict()
|
features_dict = app_model_config.to_dict()
|
||||||
|
|
||||||
user_input_form = features_dict.get('user_input_form', [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'opening_statement': features_dict.get('opening_statement'),
|
"opening_statement": features_dict.get("opening_statement"),
|
||||||
'suggested_questions': features_dict.get('suggested_questions', []),
|
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||||
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
|
"suggested_questions_after_answer": features_dict.get(
|
||||||
{"enabled": False}),
|
"suggested_questions_after_answer", {"enabled": False}
|
||||||
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
|
),
|
||||||
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
|
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||||
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
|
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||||
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
|
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||||
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
|
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||||
'user_input_form': user_input_form,
|
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||||
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
|
"user_input_form": user_input_form,
|
||||||
{"enabled": False, "type": "", "configs": []}),
|
"sensitive_word_avoidance": features_dict.get(
|
||||||
'file_upload': features_dict.get('file_upload', {"image": {
|
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||||
|
),
|
||||||
|
"file_upload": features_dict.get(
|
||||||
|
"file_upload",
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
"enabled": False,
|
"enabled": False,
|
||||||
"number_limits": 3,
|
"number_limits": 3,
|
||||||
"detail": "high",
|
"detail": "high",
|
||||||
"transfer_methods": ["remote_url", "local_file"]
|
"transfer_methods": ["remote_url", "local_file"],
|
||||||
}}),
|
|
||||||
'system_parameters': {
|
|
||||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,5 +90,5 @@ class AppMeta(WebApiResource):
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppParameterApi, '/parameters')
|
api.add_resource(AppParameterApi, "/parameters")
|
||||||
api.add_resource(AppMeta, '/meta')
|
api.add_resource(AppMeta, "/meta")
|
||||||
|
|
|
@ -31,14 +31,10 @@ from services.errors.audio import (
|
||||||
|
|
||||||
class AudioApi(WebApiResource):
|
class AudioApi(WebApiResource):
|
||||||
def post(self, app_model: App, end_user):
|
def post(self, app_model: App, end_user):
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AudioService.transcript_asr(
|
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||||
app_model=app_model,
|
|
||||||
file=file,
|
|
||||||
end_user=end_user
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
@ -70,34 +66,36 @@ class AudioApi(WebApiResource):
|
||||||
class TextApi(WebApiResource):
|
class TextApi(WebApiResource):
|
||||||
def post(self, app_model: App, end_user):
|
def post(self, app_model: App, end_user):
|
||||||
from flask_restful import reqparse
|
from flask_restful import reqparse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument('text', type=str, location='json')
|
parser.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument("streaming", type=bool, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id', None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get('text', None)
|
text = args.get("text", None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (
|
||||||
|
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict
|
||||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
):
|
||||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
|
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = args.get('voice') if args.get(
|
voice = (
|
||||||
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
args.get("voice")
|
||||||
|
if args.get("voice")
|
||||||
|
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
|
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model,
|
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||||
message_id=message_id,
|
|
||||||
end_user=end_user.external_user_id,
|
|
||||||
voice=voice,
|
|
||||||
text=text
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -127,5 +125,5 @@ class TextApi(WebApiResource):
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AudioApi, '/audio-to-text')
|
api.add_resource(AudioApi, "/audio-to-text")
|
||||||
api.add_resource(TextApi, '/text-to-audio')
|
api.add_resource(TextApi, "/text-to-audio")
|
||||||
|
|
|
@ -28,30 +28,25 @@ from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
# define completion api for user
|
# define completion api for user
|
||||||
class CompletionApi(WebApiResource):
|
class CompletionApi(WebApiResource):
|
||||||
|
|
||||||
def post(self, app_model, end_user):
|
def post(self, app_model, end_user):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, location='json', default='')
|
parser.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming
|
||||||
user=end_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -79,12 +74,12 @@ class CompletionApi(WebApiResource):
|
||||||
|
|
||||||
class CompletionStopApi(WebApiResource):
|
class CompletionStopApi(WebApiResource):
|
||||||
def post(self, app_model, end_user, task_id):
|
def post(self, app_model, end_user, task_id):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
class ChatApi(WebApiResource):
|
class ChatApi(WebApiResource):
|
||||||
|
@ -94,25 +89,21 @@ class ChatApi(WebApiResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument('query', type=str, required=True, location='json')
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument('files', type=list, required=False, location='json')
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
args['auto_generate_name'] = False
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming
|
||||||
user=end_user,
|
|
||||||
args=args,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
streaming=streaming
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -146,10 +137,10 @@ class ChatStopApi(WebApiResource):
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||||
|
|
||||||
return {'result': 'success'}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CompletionApi, '/completion-messages')
|
api.add_resource(CompletionApi, "/completion-messages")
|
||||||
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
|
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
|
||||||
api.add_resource(ChatApi, '/chat-messages')
|
api.add_resource(ChatApi, "/chat-messages")
|
||||||
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
|
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
|
||||||
|
|
|
@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
|
|
||||||
class ConversationListApi(WebApiResource):
|
class ConversationListApi(WebApiResource):
|
||||||
|
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model, end_user):
|
def get(self, app_model, end_user):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
@ -23,26 +22,32 @@ class ConversationListApi(WebApiResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
parser.add_argument(
|
||||||
required=False, default='-updated_at', location='args')
|
"sort_by",
|
||||||
|
type=str,
|
||||||
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
|
required=False,
|
||||||
|
default="-updated_at",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pinned = None
|
pinned = None
|
||||||
if 'pinned' in args and args['pinned'] is not None:
|
if "pinned" in args and args["pinned"] is not None:
|
||||||
pinned = True if args['pinned'] == 'true' else False
|
pinned = True if args["pinned"] == "true" else False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return WebConversationService.pagination_by_last_id(
|
return WebConversationService.pagination_by_last_id(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=end_user,
|
user=end_user,
|
||||||
last_id=args['last_id'],
|
last_id=args["last_id"],
|
||||||
limit=args['limit'],
|
limit=args["limit"],
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
pinned=pinned,
|
pinned=pinned,
|
||||||
sort_by=args['sort_by']
|
sort_by=args["sort_by"],
|
||||||
)
|
)
|
||||||
except LastConversationNotExistsError:
|
except LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
@ -65,7 +70,6 @@ class ConversationApi(WebApiResource):
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenameApi(WebApiResource):
|
class ConversationRenameApi(WebApiResource):
|
||||||
|
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model, end_user, c_id):
|
def post(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
@ -75,24 +79,17 @@ class ConversationRenameApi(WebApiResource):
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=False, location='json')
|
parser.add_argument("name", type=str, required=False, location="json")
|
||||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.rename(
|
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||||
app_model,
|
|
||||||
conversation_id,
|
|
||||||
end_user,
|
|
||||||
args['name'],
|
|
||||||
args['auto_generate']
|
|
||||||
)
|
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
class ConversationPinApi(WebApiResource):
|
class ConversationPinApi(WebApiResource):
|
||||||
|
|
||||||
def patch(self, app_model, end_user, c_id):
|
def patch(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||||
|
@ -120,8 +117,8 @@ class ConversationUnPinApi(WebApiResource):
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='web_conversation_name')
|
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="web_conversation_name")
|
||||||
api.add_resource(ConversationListApi, '/conversations')
|
api.add_resource(ConversationListApi, "/conversations")
|
||||||
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>')
|
api.add_resource(ConversationApi, "/conversations/<uuid:c_id>")
|
||||||
api.add_resource(ConversationPinApi, '/conversations/<uuid:c_id>/pin')
|
api.add_resource(ConversationPinApi, "/conversations/<uuid:c_id>/pin")
|
||||||
api.add_resource(ConversationUnPinApi, '/conversations/<uuid:c_id>/unpin')
|
api.add_resource(ConversationUnPinApi, "/conversations/<uuid:c_id>/unpin")
|
||||||
|
|
|
@ -2,122 +2,126 @@ from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
class AppUnavailableError(BaseHTTPException):
|
class AppUnavailableError(BaseHTTPException):
|
||||||
error_code = 'app_unavailable'
|
error_code = "app_unavailable"
|
||||||
description = "App unavailable, please check your app configurations."
|
description = "App unavailable, please check your app configurations."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotCompletionAppError(BaseHTTPException):
|
class NotCompletionAppError(BaseHTTPException):
|
||||||
error_code = 'not_completion_app'
|
error_code = "not_completion_app"
|
||||||
description = "Please check if your Completion app mode matches the right API route."
|
description = "Please check if your Completion app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotChatAppError(BaseHTTPException):
|
class NotChatAppError(BaseHTTPException):
|
||||||
error_code = 'not_chat_app'
|
error_code = "not_chat_app"
|
||||||
description = "Please check if your app mode matches the right API route."
|
description = "Please check if your app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NotWorkflowAppError(BaseHTTPException):
|
class NotWorkflowAppError(BaseHTTPException):
|
||||||
error_code = 'not_workflow_app'
|
error_code = "not_workflow_app"
|
||||||
description = "Please check if your Workflow app mode matches the right API route."
|
description = "Please check if your Workflow app mode matches the right API route."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ConversationCompletedError(BaseHTTPException):
|
class ConversationCompletedError(BaseHTTPException):
|
||||||
error_code = 'conversation_completed'
|
error_code = "conversation_completed"
|
||||||
description = "The conversation has ended. Please start a new conversation."
|
description = "The conversation has ended. Please start a new conversation."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotInitializeError(BaseHTTPException):
|
class ProviderNotInitializeError(BaseHTTPException):
|
||||||
error_code = 'provider_not_initialize'
|
error_code = "provider_not_initialize"
|
||||||
description = "No valid model provider credentials found. " \
|
description = (
|
||||||
|
"No valid model provider credentials found. "
|
||||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderQuotaExceededError(BaseHTTPException):
|
class ProviderQuotaExceededError(BaseHTTPException):
|
||||||
error_code = 'provider_quota_exceeded'
|
error_code = "provider_quota_exceeded"
|
||||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
description = (
|
||||||
|
"Your quota for Dify Hosted OpenAI has been exhausted. "
|
||||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||||
|
)
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
||||||
error_code = 'model_currently_not_support'
|
error_code = "model_currently_not_support"
|
||||||
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestError(BaseHTTPException):
|
class CompletionRequestError(BaseHTTPException):
|
||||||
error_code = 'completion_request_error'
|
error_code = "completion_request_error"
|
||||||
description = "Completion request failed."
|
description = "Completion request failed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AppMoreLikeThisDisabledError(BaseHTTPException):
|
class AppMoreLikeThisDisabledError(BaseHTTPException):
|
||||||
error_code = 'app_more_like_this_disabled'
|
error_code = "app_more_like_this_disabled"
|
||||||
description = "The 'More like this' feature is disabled. Please refresh your page."
|
description = "The 'More like this' feature is disabled. Please refresh your page."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||||
error_code = 'app_suggested_questions_after_answer_disabled'
|
error_code = "app_suggested_questions_after_answer_disabled"
|
||||||
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
|
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
class NoAudioUploadedError(BaseHTTPException):
|
class NoAudioUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_audio_uploaded'
|
error_code = "no_audio_uploaded"
|
||||||
description = "Please upload your audio."
|
description = "Please upload your audio."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class AudioTooLargeError(BaseHTTPException):
|
class AudioTooLargeError(BaseHTTPException):
|
||||||
error_code = 'audio_too_large'
|
error_code = "audio_too_large"
|
||||||
description = "Audio size exceeded. {message}"
|
description = "Audio size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_audio_type'
|
error_code = "unsupported_audio_type"
|
||||||
description = "Audio type not allowed."
|
description = "Audio type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||||
error_code = 'provider_not_support_speech_to_text'
|
error_code = "provider_not_support_speech_to_text"
|
||||||
description = "Provider not support speech to text."
|
description = "Provider not support speech to text."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class NoFileUploadedError(BaseHTTPException):
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
error_code = 'no_file_uploaded'
|
error_code = "no_file_uploaded"
|
||||||
description = "Please upload your file."
|
description = "Please upload your file."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = 'too_many_files'
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class FileTooLargeError(BaseHTTPException):
|
class FileTooLargeError(BaseHTTPException):
|
||||||
error_code = 'file_too_large'
|
error_code = "file_too_large"
|
||||||
description = "File size exceeded. {message}"
|
description = "File size exceeded. {message}"
|
||||||
code = 413
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(BaseHTTPException):
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
error_code = 'unsupported_file_type'
|
error_code = "unsupported_file_type"
|
||||||
description = "File type not allowed."
|
description = "File type not allowed."
|
||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class WebSSOAuthRequiredError(BaseHTTPException):
|
class WebSSOAuthRequiredError(BaseHTTPException):
|
||||||
error_code = 'web_sso_auth_required'
|
error_code = "web_sso_auth_required"
|
||||||
description = "Web SSO authentication required."
|
description = "Web SSO authentication required."
|
||||||
code = 401
|
code = 401
|
||||||
|
|
|
@ -9,4 +9,4 @@ class SystemFeatureApi(Resource):
|
||||||
return FeatureService.get_system_features().model_dump()
|
return FeatureService.get_system_features().model_dump()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(SystemFeatureApi, '/system-features')
|
api.add_resource(SystemFeatureApi, "/system-features")
|
||||||
|
|
|
@ -10,14 +10,13 @@ from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
class FileApi(WebApiResource):
|
class FileApi(WebApiResource):
|
||||||
|
|
||||||
@marshal_with(file_fields)
|
@marshal_with(file_fields)
|
||||||
def post(self, app_model, end_user):
|
def post(self, app_model, end_user):
|
||||||
# get file from request
|
# get file from request
|
||||||
file = request.files['file']
|
file = request.files["file"]
|
||||||
|
|
||||||
# check file
|
# check file
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
|
@ -32,4 +31,4 @@ class FileApi(WebApiResource):
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, '/files/upload')
|
api.add_resource(FileApi, "/files/upload")
|
||||||
|
|
|
@ -33,48 +33,46 @@ from services.message_service import MessageService
|
||||||
|
|
||||||
|
|
||||||
class MessageListApi(WebApiResource):
|
class MessageListApi(WebApiResource):
|
||||||
feedback_fields = {
|
feedback_fields = {"rating": fields.String}
|
||||||
'rating': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
retriever_resource_fields = {
|
retriever_resource_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'message_id': fields.String,
|
"message_id": fields.String,
|
||||||
'position': fields.Integer,
|
"position": fields.Integer,
|
||||||
'dataset_id': fields.String,
|
"dataset_id": fields.String,
|
||||||
'dataset_name': fields.String,
|
"dataset_name": fields.String,
|
||||||
'document_id': fields.String,
|
"document_id": fields.String,
|
||||||
'document_name': fields.String,
|
"document_name": fields.String,
|
||||||
'data_source_type': fields.String,
|
"data_source_type": fields.String,
|
||||||
'segment_id': fields.String,
|
"segment_id": fields.String,
|
||||||
'score': fields.Float,
|
"score": fields.Float,
|
||||||
'hit_count': fields.Integer,
|
"hit_count": fields.Integer,
|
||||||
'word_count': fields.Integer,
|
"word_count": fields.Integer,
|
||||||
'segment_position': fields.Integer,
|
"segment_position": fields.Integer,
|
||||||
'index_node_hash': fields.String,
|
"index_node_hash": fields.String,
|
||||||
'content': fields.String,
|
"content": fields.String,
|
||||||
'created_at': TimestampField
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
message_fields = {
|
message_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'conversation_id': fields.String,
|
"conversation_id": fields.String,
|
||||||
'inputs': fields.Raw,
|
"inputs": fields.Raw,
|
||||||
'query': fields.String,
|
"query": fields.String,
|
||||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
'created_at': TimestampField,
|
"created_at": TimestampField,
|
||||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
|
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||||
'status': fields.String,
|
"status": fields.String,
|
||||||
'error': fields.String,
|
"error": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
message_infinite_scroll_pagination_fields = {
|
message_infinite_scroll_pagination_fields = {
|
||||||
'limit': fields.Integer,
|
"limit": fields.Integer,
|
||||||
'has_more': fields.Boolean,
|
"has_more": fields.Boolean,
|
||||||
'data': fields.List(fields.Nested(message_fields))
|
"data": fields.List(fields.Nested(message_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
|
@ -84,14 +82,15 @@ class MessageListApi(WebApiResource):
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return MessageService.pagination_by_first_id(app_model, end_user,
|
return MessageService.pagination_by_first_id(
|
||||||
args['conversation_id'], args['first_id'], args['limit'])
|
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
|
)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
except services.errors.message.FirstMessageNotExistsError:
|
except services.errors.message.FirstMessageNotExistsError:
|
||||||
|
@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
|
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
|
||||||
except services.errors.message.MessageNotExistsError:
|
except services.errors.message.MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class MessageMoreLikeThisApi(WebApiResource):
|
class MessageMoreLikeThisApi(WebApiResource):
|
||||||
def get(self, app_model, end_user, message_id):
|
def get(self, app_model, end_user, message_id):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
parser.add_argument(
|
||||||
|
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args['response_mode'] == 'streaming'
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
|
@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||||
user=end_user,
|
user=end_user,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
streaming=streaming
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
|
@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model,
|
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||||
user=end_user,
|
|
||||||
message_id=message_id,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP
|
|
||||||
)
|
)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message not found")
|
raise NotFound("Message not found")
|
||||||
|
@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
return {'data': questions}
|
return {"data": questions}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(MessageListApi, '/messages')
|
api.add_resource(MessageListApi, "/messages")
|
||||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
|
||||||
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
|
api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this")
|
||||||
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
|
api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions")
|
||||||
|
|
|
@ -15,33 +15,31 @@ from services.feature_service import FeatureService
|
||||||
|
|
||||||
class PassportResource(Resource):
|
class PassportResource(Resource):
|
||||||
"""Base resource for passport."""
|
"""Base resource for passport."""
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
system_features = FeatureService.get_system_features()
|
system_features = FeatureService.get_system_features()
|
||||||
app_code = request.headers.get('X-App-Code')
|
app_code = request.headers.get("X-App-Code")
|
||||||
if app_code is None:
|
if app_code is None:
|
||||||
raise Unauthorized('X-App-Code header is missing.')
|
raise Unauthorized("X-App-Code header is missing.")
|
||||||
|
|
||||||
if system_features.sso_enforced_for_web:
|
if system_features.sso_enforced_for_web:
|
||||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False)
|
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
||||||
if app_web_sso_enabled:
|
if app_web_sso_enabled:
|
||||||
raise WebSSOAuthRequiredError()
|
raise WebSSOAuthRequiredError()
|
||||||
|
|
||||||
# get site from db and check if it is normal
|
# get site from db and check if it is normal
|
||||||
site = db.session.query(Site).filter(
|
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||||
Site.code == app_code,
|
|
||||||
Site.status == 'normal'
|
|
||||||
).first()
|
|
||||||
if not site:
|
if not site:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
# get app from db and check if it is normal and enable_site
|
# get app from db and check if it is normal and enable_site
|
||||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||||
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
end_user = EndUser(
|
end_user = EndUser(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
type='browser',
|
type="browser",
|
||||||
is_anonymous=True,
|
is_anonymous=True,
|
||||||
session_id=generate_session_id(),
|
session_id=generate_session_id(),
|
||||||
)
|
)
|
||||||
|
@ -51,20 +49,20 @@ class PassportResource(Resource):
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"iss": site.app_id,
|
"iss": site.app_id,
|
||||||
'sub': 'Web API Passport',
|
"sub": "Web API Passport",
|
||||||
'app_id': site.app_id,
|
"app_id": site.app_id,
|
||||||
'app_code': app_code,
|
"app_code": app_code,
|
||||||
'end_user_id': end_user.id,
|
"end_user_id": end_user.id,
|
||||||
}
|
}
|
||||||
|
|
||||||
tk = PassportService().issue(payload)
|
tk = PassportService().issue(payload)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'access_token': tk,
|
"access_token": tk,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(PassportResource, '/passport')
|
api.add_resource(PassportResource, "/passport")
|
||||||
|
|
||||||
|
|
||||||
def generate_session_id():
|
def generate_session_id():
|
||||||
|
@ -73,7 +71,6 @@ def generate_session_id():
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
existing_count = db.session.query(EndUser) \
|
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
|
||||||
.filter(EndUser.session_id == session_id).count()
|
|
||||||
if existing_count == 0:
|
if existing_count == 0:
|
||||||
return session_id
|
return session_id
|
||||||
|
|
|
@ -10,67 +10,65 @@ from libs.helper import TimestampField, uuid_value
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
feedback_fields = {
|
feedback_fields = {"rating": fields.String}
|
||||||
'rating': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
message_fields = {
|
message_fields = {
|
||||||
'id': fields.String,
|
"id": fields.String,
|
||||||
'inputs': fields.Raw,
|
"inputs": fields.Raw,
|
||||||
'query': fields.String,
|
"query": fields.String,
|
||||||
'answer': fields.String,
|
"answer": fields.String,
|
||||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
'created_at': TimestampField
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListApi(WebApiResource):
|
class SavedMessageListApi(WebApiResource):
|
||||||
saved_message_infinite_scroll_pagination_fields = {
|
saved_message_infinite_scroll_pagination_fields = {
|
||||||
'limit': fields.Integer,
|
"limit": fields.Integer,
|
||||||
'has_more': fields.Boolean,
|
"has_more": fields.Boolean,
|
||||||
'data': fields.List(fields.Nested(message_fields))
|
"data": fields.List(fields.Nested(message_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model, end_user):
|
def get(self, app_model, end_user):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
|
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, app_model, end_user):
|
def post(self, app_model, end_user):
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
|
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
SavedMessageService.save(app_model, end_user, args['message_id'])
|
SavedMessageService.save(app_model, end_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageApi(WebApiResource):
|
class SavedMessageApi(WebApiResource):
|
||||||
def delete(self, app_model, end_user, message_id):
|
def delete(self, app_model, end_user, message_id):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
if app_model.mode != 'completion':
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
SavedMessageService.delete(app_model, end_user, message_id)
|
SavedMessageService.delete(app_model, end_user, message_id)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(SavedMessageListApi, '/saved-messages')
|
api.add_resource(SavedMessageListApi, "/saved-messages")
|
||||||
api.add_resource(SavedMessageApi, '/saved-messages/<uuid:message_id>')
|
api.add_resource(SavedMessageApi, "/saved-messages/<uuid:message_id>")
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user