Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)

feat: [backend] vision support (#1510)

Co-authored-by: Garfield Dai <dai.hai@foxmail.com>

parent d0e1ea8f06, commit 41d0a8b295
@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
APP_API_URL=http://127.0.0.1:5001
APP_WEB_URL=http://127.0.0.1:3000

+# Files URL
+FILES_URL=http://127.0.0.1:5001
+
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -70,6 +73,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false

+# Upload configuration
+UPLOAD_FILE_SIZE_LIMIT=15
+UPLOAD_FILE_BATCH_LIMIT=5
+UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+
+# Model Configuration
+MULTIMODAL_SEND_IMAGE_FORMAT=base64
+
# Mail configuration, support: resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
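
How the new settings fit together, as a rough sketch (the helper below and its arguments are illustrative, not code from this commit): MULTIMODAL_SEND_IMAGE_FORMAT decides whether an uploaded image reaches the model as inline base64 data or as a signed URL served under FILES_URL.

# Sketch only: illustrates the base64-vs-url switch introduced by this commit.
# build_image_payload is a hypothetical helper, not commit code.
import base64

def build_image_payload(image_bytes: bytes, signed_url: str, send_format: str) -> dict:
    if send_format == 'base64':
        # inline the image content directly in the prompt payload
        return {'type': 'image', 'data': base64.b64encode(image_bytes).decode()}
    # otherwise pass a signed, expiring preview URL (served under FILES_URL)
    return {'type': 'image', 'data': signed_url}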
@@ -126,6 +126,7 @@ def register_blueprints(app):
    from controllers.service_api import bp as service_api_bp
    from controllers.web import bp as web_bp
    from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp

    CORS(service_api_bp,
         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):

    app.register_blueprint(console_app_bp)

+    CORS(files_bp,
+         allow_headers=['Content-Type'],
+         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
+         )
+    app.register_blueprint(files_bp)
+

# create app
app = create_app()
229  api/config.py
@@ -26,6 +26,7 @@ DEFAULTS = {
    'SERVICE_API_URL': 'https://api.dify.ai',
    'APP_WEB_URL': 'https://udify.app',
    'APP_API_URL': 'https://udify.app',
+    'FILES_URL': '',
    'STORAGE_TYPE': 'local',
    'STORAGE_LOCAL_PATH': 'storage',
    'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
    'CLEAN_DAY_SETTING': 30,
    'UPLOAD_FILE_SIZE_LIMIT': 15,
    'UPLOAD_FILE_BATCH_LIMIT': 5,
-    'OUTPUT_MODERATION_BUFFER_SIZE': 300
+    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
+    'OUTPUT_MODERATION_BUFFER_SIZE': 300,
+    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
}
@@ -84,15 +87,9 @@ class Config:
    """Application configuration class."""

    def __init__(self):
-        # app settings
-        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
-        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
-        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
-        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
-        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
-        self.CONSOLE_URL = get_env('CONSOLE_URL')
-        self.API_URL = get_env('API_URL')
-        self.APP_URL = get_env('APP_URL')
+        # ------------------------
+        # General Configurations.
+        # ------------------------
        self.CURRENT_VERSION = "0.3.29"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
@@ -100,70 +97,55 @@ class Config:
        self.TESTING = False
        self.LOG_LEVEL = get_env('LOG_LEVEL')

        # The backend URL prefix of the console API.
        # used to concatenate the login authorization callback or notion integration callback.
        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')

        # The front-end URL prefix of the console web.
        # used to concatenate some front-end addresses and for CORS configuration use.
        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')

        # WebApp API backend Url prefix.
        # used to declare the back-end URL for the front-end API.
        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')

        # WebApp Url prefix.
        # used to display WebAPP API Base Url to the front-end.
        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')

        # Service API Url prefix.
        # used to display Service API Base Url to the front-end.
        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')

        # File preview or download Url prefix.
        # used to display File preview or download Url to the front-end or as Multi-model inputs;
        # Url is signed and has expiration time.
        self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL

        # Fallback Url prefix.
        # Will be deprecated in the future.
        self.CONSOLE_URL = get_env('CONSOLE_URL')
        self.API_URL = get_env('API_URL')
        self.APP_URL = get_env('APP_URL')

        # Your App secret key will be used for securely signing the session cookie
        # Make sure you are changing this key for your deployment with a strong key.
        # You can generate a strong key using `openssl rand -base64 42`.
        # Alternatively you can set it with `SECRET_KEY` environment variable.
        self.SECRET_KEY = get_env('SECRET_KEY')

        # redis settings
        self.REDIS_HOST = get_env('REDIS_HOST')
        self.REDIS_PORT = get_env('REDIS_PORT')
        self.REDIS_USERNAME = get_env('REDIS_USERNAME')
        self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
        self.REDIS_DB = get_env('REDIS_DB')
        self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')

        # storage settings
        self.STORAGE_TYPE = get_env('STORAGE_TYPE')
        self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
        self.S3_ENDPOINT = get_env('S3_ENDPOINT')
        self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
        self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
        self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
        self.S3_REGION = get_env('S3_REGION')

        # vector store settings, only support weaviate, qdrant
        self.VECTOR_STORE = get_env('VECTOR_STORE')

        # weaviate settings
        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))

        # qdrant settings
        self.QDRANT_URL = get_env('QDRANT_URL')
        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')

        # milvus setting
        self.MILVUS_HOST = get_env('MILVUS_HOST')
        self.MILVUS_PORT = get_env('MILVUS_PORT')
        self.MILVUS_USER = get_env('MILVUS_USER')
        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
        self.MILVUS_SECURE = get_env('MILVUS_SECURE')


        # cors settings
        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
            'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
            'WEB_API_CORS_ALLOW_ORIGINS', '*')

        # mail settings
        self.MAIL_TYPE = get_env('MAIL_TYPE')
        self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
        self.RESEND_API_KEY = get_env('RESEND_API_KEY')

        # sentry settings
        self.SENTRY_DSN = get_env('SENTRY_DSN')
        self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
        self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))

        # check update url
        self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')

        # database settings
        # ------------------------
        # Database Configurations.
        # ------------------------
        db_credentials = {
            key: get_env(key) for key in
            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
@@ -177,14 +159,102 @@ class Config:

        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')

        # celery settings
        # ------------------------
        # Redis Configurations.
        # ------------------------
        self.REDIS_HOST = get_env('REDIS_HOST')
        self.REDIS_PORT = get_env('REDIS_PORT')
        self.REDIS_USERNAME = get_env('REDIS_USERNAME')
        self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
        self.REDIS_DB = get_env('REDIS_DB')
        self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')

        # ------------------------
        # Celery worker Configurations.
        # ------------------------
        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')

        # hosted provider credentials
        # ------------------------
        # File Storage Configurations.
        # ------------------------
        self.STORAGE_TYPE = get_env('STORAGE_TYPE')
        self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
        self.S3_ENDPOINT = get_env('S3_ENDPOINT')
        self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
        self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
        self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
        self.S3_REGION = get_env('S3_REGION')

        # ------------------------
        # Vector Store Configurations.
        # Currently, only support: qdrant, milvus, zilliz, weaviate
        # ------------------------
        self.VECTOR_STORE = get_env('VECTOR_STORE')

        # qdrant settings
        self.QDRANT_URL = get_env('QDRANT_URL')
        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')

        # milvus / zilliz setting
        self.MILVUS_HOST = get_env('MILVUS_HOST')
        self.MILVUS_PORT = get_env('MILVUS_PORT')
        self.MILVUS_USER = get_env('MILVUS_USER')
        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
        self.MILVUS_SECURE = get_env('MILVUS_SECURE')

        # weaviate settings
        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))

        # ------------------------
        # Mail Configurations.
        # ------------------------
        self.MAIL_TYPE = get_env('MAIL_TYPE')
        self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
        self.RESEND_API_KEY = get_env('RESEND_API_KEY')

        # ------------------------
        # Sentry Configurations.
        # ------------------------
        self.SENTRY_DSN = get_env('SENTRY_DSN')
        self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
        self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))

        # ------------------------
        # Business Configurations.
        # ------------------------

        # multi model send image format, support base64, url, default is base64
        self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')

        # Dataset Configurations.
        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')

        # File upload Configurations.
        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
        self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))

        # Moderation in app Configurations.
        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))

        # Notion integration setting
        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')

        # ------------------------
        # Platform Configurations.
        # ------------------------
        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
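
The Celery section above keeps the existing backend-selection rule: results go to the database (via SQLAlchemy's 'db+' URL scheme) only when CELERY_BACKEND is set to 'database', otherwise they go to the broker. A minimal, self-contained illustration:

# Illustration of the CELERY_RESULT_BACKEND selection logic shown above.
def celery_result_backend(celery_backend: str, broker_url: str, sqlalchemy_uri: str) -> str:
    if celery_backend == 'database':
        # e.g. 'db+postgresql://user:pass@host:5432/dify'
        return 'db+{}'.format(sqlalchemy_uri)
    return broker_url  # e.g. 'redis://:difyai123456@localhost:6379/1'

assert celery_result_backend('database', 'redis://x', 'postgresql://u:p@h/db') == 'db+postgresql://u:p@h/db'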
@@ -212,26 +282,6 @@ class Config:
        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')

-        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
-        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
-
-        # notion import setting
-        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
-        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
-        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
-        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
-        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
-
-        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
-        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
-
-        # uploading settings
-        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
-        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
-
-        # moderation settings
-        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
-

class CloudEditionConfig(Config):
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
        self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
        self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')

-
-class TestConfig(Config):
-
-    def __init__(self):
-        super().__init__()
-
-        self.EDITION = "SELF_HOSTED"
-        self.TESTING = True
-
-        db_credentials = {
-            key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
-        }
-
-        # use a different database for testing: dify_test
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
@@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('model_config', type=dict, required=True, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False

        account = flask_login.current_user
@@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('model_config', type=dict, required=True, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
@@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
        args = parser.parse_args()

        streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False

        account = flask_login.current_user
@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
        conversation_id = str(conversation_id)

        return _get_conversation(app_id, conversation_id, 'completion')


    @setup_required
    @login_required
    @account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
        conversation_id = str(conversation_id)

        return _get_conversation(app_id, conversation_id, 'chat')


    @setup_required
    @login_required
    @account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
        return {'result': 'success'}, 204
-
-


api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
@@ -1,7 +1,6 @@
import datetime
import json

-from cachetools import TTLCache
from flask import request
from flask_login import current_user
from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task

-cache = TTLCache(maxsize=None, ttl=30)
-

class DataSourceApi(Resource):
@@ -1,5 +1,5 @@
-from cachetools import TTLCache
from flask import request, current_app
from flask_login import current_user

import services
+from libs.login import login_required
@@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields

from services.file_service import FileService

-cache = TTLCache(maxsize=None, ttl=30)
-
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
@@ -30,9 +27,11 @@ class FileApi(Resource):
    def get(self):
        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
        batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
+        image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
        return {
            'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit
+            'batch_count_limit': batch_count_limit,
+            'image_file_size_limit': image_file_size_limit
        }, 200

    @setup_required
@@ -51,7 +50,7 @@ class FileApi(Resource):
        if len(request.files) > 1:
            raise TooManyFilesError()
        try:
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
from typing import Generator, Union

from flask import Response, stream_with_context
@@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False

+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
+
        try:
            response = CompletionService.completion(
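
With the new 'files' argument, a request body to these endpoints can attach images. A plausible payload, with the item shape inferred from the FileObj fields introduced later in this commit (type, transfer_method, url, upload_file_id); the exact validation code is not part of this diff:

# A plausible request body for the new 'files' argument (shape inferred from
# the FileObj fields added in this commit; validation code is not in this diff).
payload = {
    'inputs': {},
    'query': 'What is in this image?',
    'response_mode': 'streaming',
    'files': [
        {'type': 'image', 'transfer_method': 'remote_url',
         'url': 'https://example.com/cat.png'},
        {'type': 'image', 'transfer_method': 'local_file',
         'upload_file_id': '<uuid returned by /files/upload>'},
    ],
}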
@@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False

+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
+
        try:
            response = CompletionService.completion(
@@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
                user=current_user,
                last_id=args['last_id'],
                limit=args['limit'],
-                pinned=pinned
+                pinned=pinned,
+                exclude_debug_conversation=True
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
@@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        args = parser.parse_args()

        try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
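
The rename endpoints now accept an 'auto_generate' flag alongside an optional 'name'. ConversationService.rename is not part of this diff, but the intended branching is presumably along these lines (self-contained sketch; the lookup, persistence and LLM-based title generator are stubbed):

# Sketch of the presumed ConversationService.rename branching (not commit code).
from dataclasses import dataclass

@dataclass
class Conversation:
    id: str
    name: str
    first_query: str

def auto_title(conversation: Conversation) -> str:
    # stand-in for the LLM-generated name used when auto_generate is true
    return conversation.first_query[:20] or 'New conversation'

def rename(conversation: Conversation, name: str, auto_generate: bool) -> Conversation:
    conversation.name = auto_title(conversation) if auto_generate else name
    return conversation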
@@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
            }
            for installed_app in installed_apps
        ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
-                                             if app['last_used_at'] is not None else datetime.min))
+        installed_apps.sort(key=lambda app: (-app['is_pinned'],
+                                             app['last_used_at'] is None,
+                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))

        return {'installed_apps': installed_apps}
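
The old key substituted datetime.min for missing timestamps, which sorted never-used apps first and the rest oldest-first. The new three-part key pins first, pushes never-used apps last, and orders the rest most-recently-used first. A quick demonstration:

# Demonstrates the fixed sort key: pinned first, never-used last, recent first.
from datetime import datetime

apps = [
    {'name': 'a', 'is_pinned': False, 'last_used_at': datetime(2023, 1, 1)},
    {'name': 'b', 'is_pinned': True,  'last_used_at': None},
    {'name': 'c', 'is_pinned': False, 'last_used_at': datetime(2023, 6, 1)},
    {'name': 'd', 'is_pinned': False, 'last_used_at': None},
]
apps.sort(key=lambda app: (-app['is_pinned'],
                           app['last_used_at'] is None,
                           -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
print([a['name'] for a in apps])  # ['b', 'c', 'a', 'd']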
@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
+from flask import current_app

from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
        'options': fields.List(fields.String)
    }

+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
    }

    @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
        }
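
After this change, the /parameters payload gains two blocks. A representative response, written as a Python dict; the inner shape of 'file_upload' is not defined in this diff, so those keys are illustrative only:

# Abridged, illustrative /parameters response after this change.
response = {
    'opening_statement': '...',
    'user_input_form': [],
    'sensitive_word_avoidance': {'enabled': False},
    'file_upload': {                      # from app_model_config.file_upload_dict
        'image': {'enabled': True},       # inner keys illustrative, not in this diff
    },
    'system_parameters': {                # from UPLOAD_IMAGE_FILE_SIZE_LIMIT
        'image_file_size_limit': '10',    # declared as fields.String above
    },
}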
@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields

feedback_fields = {
    'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
    'inputs': fields.Raw,
    'query': fields.String,
    'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
    'created_at': TimestampField
}
@@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('provider', type=str, required=True, location='json')
        parser.add_argument('model', type=str, required=True, location='json')
@@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
        del args['model']
        del args['tools']

+        args['auto_generate_name'] = False
+
        try:
            response = CompletionService.completion(
                app_model=app_model,
@@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        args = parser.parse_args()

        try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
10  api/controllers/files/__init__.py (new file)
@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint('files', __name__)
+api = ExternalApi(bp)
+
+
+from . import image_preview
40  api/controllers/files/image_preview.py (new file)
@@ -0,0 +1,40 @@
+from flask import request, Response
+from flask_restful import Resource
+
+import services
+from controllers.files import api
+from libs.exception import BaseHTTPException
+from services.file_service import FileService
+
+
+class ImagePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get('timestamp')
+        nonce = request.args.get('nonce')
+        sign = request.args.get('sign')
+
+        if not timestamp or not nonce or not sign:
+            return {'content': 'Invalid request.'}, 400
+
+        try:
+            generator, mimetype = FileService.get_image_preview(
+                file_id,
+                timestamp,
+                nonce,
+                sign
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
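
FileService.get_image_preview (not shown in this diff) verifies the timestamp/nonce/sign triple before streaming the file; the config.py comment above says the URL "is signed and has expiration time". One common way to implement such a check, as a sketch only — the key choice, message layout and 5-minute TTL are assumptions, not commit code:

# Sketch of signed-URL verification (implementation not in this diff;
# the secret key use, message layout and TTL below are assumptions).
import hashlib, hmac, time

SIGN_TTL_SECONDS = 300

def verify_image_sign(secret_key: bytes, file_id: str,
                      timestamp: str, nonce: str, sign: str) -> bool:
    if int(timestamp) < time.time() - SIGN_TTL_SECONDS:
        return False  # link expired
    message = f'image-preview|{file_id}|{timestamp}|{nonce}'
    expected = hmac.new(secret_key, message.encode(), hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, sign)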
@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp)


-from .app import completion, app, conversation, message, audio
+from .app import completion, app, conversation, message, audio, file

from .dataset import document, segment, dataset
@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with
+from flask import current_app

from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
        'options': fields.List(fields.String)
    }

+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
    }

    @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
        }
@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('user', type=str, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
        if end_user is None and args['user'] is not None:
            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

+        args['auto_generate_name'] = False
+
        try:
            response = CompletionService.completion(
                app_model=app_model,
                user=end_user,
                args=args,
                from_source='api',
-                streaming=streaming
+                streaming=streaming,
            )

            return compact_response(response)
@@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('user', type=str, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')

        args = parser.parse_args()
@@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
        parser.add_argument('user', type=str, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        args = parser.parse_args()

        if end_user is None and args['user'] is not None:
            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

        try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
@@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    description = "Provider not support speech to text."
    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
42  api/controllers/service_api/app/file.py (new file)
@@ -0,0 +1,42 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.service_api import api
+from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(AppApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+
+        file = request.files['file']
+        user_args = request.form.get('user')
+
+        if end_user is None and user_args is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')
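
Usage of the new endpoint, sketched with the requests library (the multipart field names 'file' and 'user' come from the handler above; the host and key are placeholders):

# Sketch: uploading an image to the new Service API endpoint with requests.
import requests

resp = requests.post(
    'http://127.0.0.1:5001/v1/files/upload',            # SERVICE_API_URL + /v1 prefix
    headers={'Authorization': 'Bearer <app-api-key>'},  # placeholder key
    files={'file': ('cat.png', open('cat.png', 'rb'), 'image/png')},
    data={'user': 'end-user-123'},                      # optional end-user id
)
print(resp.status_code, resp.json())  # 201 and the marshalled file_fields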
@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Message, EndUser

from fields.conversation_fields import message_file_fields

class MessageListApi(AppApiResource):
    feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
        'inputs': fields.Raw,
        'query': fields.String,
        'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
@@ -2,6 +2,7 @@ import json

from flask import request
from flask_restful import reqparse, marshal
+from flask_login import current_user
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
@@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
        if len(request.files) > 1:
            raise TooManyFilesError()

-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
        data_source = {
            'type': 'upload_file',
            'info_list': {
@@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
        if len(request.files) > 1:
            raise TooManyFilesError()

-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
        data_source = {
            'type': 'upload_file',
            'info_list': {
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)


-from . import completion, app, conversation, message, site, saved_message, audio, passport
+from . import completion, app, conversation, message, site, saved_message, audio, passport, file
@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
+from flask import current_app

from controllers.web import api
from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
        'options': fields.List(fields.String)
    }

+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
    }

    @marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
        }
@@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')

        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False

        try:
            response = CompletionService.completion(
@@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False

        try:
            response = CompletionService.completion(
@@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        args = parser.parse_args()

        try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
@@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    error_code = 'provider_not_support_speech_to_text'
    description = "Provider not support speech to text."
-    code = 400
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
36  api/controllers/web/file.py (new file)
@@ -0,0 +1,36 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.web import api
+from controllers.web.wraps import WebApiResource
+from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(WebApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+        # get file from request
+        file = request.files['file']
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')
@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
+from fields.conversation_fields import message_file_fields


class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
        'inputs': fields.Raw,
        'query': fields.String,
        'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
+

feedback_fields = {
    'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
    'inputs': fields.Raw,
    'query': fields.String,
    'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
    'created_at': TimestampField
}
@@ -11,7 +11,8 @@ from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
+    ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
@@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):

            real_prompts.append({
                "role": role,
-                "text": message.content
+                "text": message.content,
+                "files": [{
+                    "type": file.type.value,
+                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
+                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
+                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
            })

        self.llm_message.prompt = real_prompts
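
The logged prompt keeps only the first and last ten characters of each file's data, so base64 blobs do not flood the stored prompt. For example:

# The truncation used above, applied to a long base64 string.
data = 'iVBORw0KGgoAAAANSUhEUg' + 'A' * 5000 + 'ErkJggg=='
print(data[:10] + '...[TRUNCATED]...' + data[-10:])
# iVBORw0KGg...[TRUNCATED]...AErkJggg==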
@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
+from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory
class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
-                 is_override: bool = False, retriever_from: str = 'dev'):
+                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
+                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
+                 auto_generate_name: bool = True):
        """
        errors: ProviderTokenNotInitError
        """
@@ -64,16 +66,21 @@ class Completion:
            is_override=is_override,
            inputs=inputs,
            query=query,
+            files=files,
            streaming=streaming,
-            model_instance=final_model_instance
+            model_instance=final_model_instance,
+            auto_generate_name=auto_generate_name
        )

+        prompt_message_files = [file.prompt_message_file for file in files]
+
        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
-            inputs=inputs
+            inputs=inputs,
+            files=prompt_message_files
        )

        # init orchestrator rule parser
@@ -95,6 +102,7 @@ class Completion:
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
+                files=prompt_message_files,
                agent_execute_result=None,
                conversation_message_task=conversation_message_task,
                memory=memory,
@@ -146,6 +154,7 @@ class Completion:
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
+                files=prompt_message_files,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
@@ -257,6 +266,7 @@ class Completion:
    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
+                      files: List[PromptMessageFile],
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -266,10 +276,12 @@ class Completion:
        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
+                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
@@ -280,6 +292,7 @@ class Completion:
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
+                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
@@ -337,7 +350,7 @@ class Completion:

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict) -> int:
+                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens
@@ -348,15 +361,16 @@ class Completion:
            max_tokens = 0

        prompt_transform = PromptTransform()
-        prompt_messages = []

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
+                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
@@ -367,6 +381,7 @@ class Completion:
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
+                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
+from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,14 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource
+    MessageChain, DatasetRetrieverResource, MessageFile


class ConversationMessageTask:
    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
-                 conversation: Optional[Conversation] = None, is_override: bool = False):
+                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
+                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
+                 auto_generate_name: bool = True):
        self.start_at = time.perf_counter()

        self.task_id = task_id
@@ -35,6 +37,7 @@ class ConversationMessageTask:
        self.user = user
        self.inputs = inputs
        self.query = query
+        self.files = files
        self.streaming = streaming

        self.conversation = conversation
@@ -45,6 +48,7 @@ class ConversationMessageTask:
        self.message = None

        self.retriever_resource = None
+        self.auto_generate_name = auto_generate_name

        self.model_dict = self.app_model_config.model_dict
        self.provider_name = self.model_dict.get('provider')
@ -100,7 +104,7 @@ class ConversationMessageTask:
|
|||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=self.mode,
|
||||
name='',
|
||||
name='New conversation',
|
||||
inputs=self.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction=system_instruction,
|
||||
|
@ -142,6 +146,19 @@ class ConversationMessageTask:
|
|||
db.session.add(self.message)
|
||||
db.session.commit()
|
||||
|
||||
for file in self.files:
|
||||
message_file = MessageFile(
|
||||
message_id=self.message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
if text is not None:
|
||||
self._pub_handler.pub_text(text)
|
||||
|
@ -176,7 +193,8 @@ class ConversationMessageTask:
|
|||
message_was_created.send(
|
||||
self.message,
|
||||
conversation=self.conversation,
|
||||
is_first_message=self.is_new_conversation
|
||||
is_first_message=self.is_new_conversation,
|
||||
auto_generate_name=self.auto_generate_name
|
||||
)
|
||||
|
||||
if not by_stopped:
|
||||
|
|
api/core/file/__init__.py (new file, 0 lines)
api/core/file/file_obj.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import enum
from typing import Optional

from pydantic import BaseModel

from core.file.upload_file_parser import UploadFileParser
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
from extensions.ext_database import db
from models.model import UploadFile


class FileType(enum.Enum):
    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        for member in FileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(enum.Enum):
    REMOTE_URL = 'remote_url'
    LOCAL_FILE = 'local_file'

    @staticmethod
    def value_of(value):
        for member in FileTransferMethod:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileObj(BaseModel):
    id: Optional[str]
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str]
    upload_file_id: Optional[str]
    file_config: dict

    @property
    def data(self) -> Optional[str]:
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        return self._get_data(force_url=True)

    @property
    def prompt_message_file(self) -> PromptMessageFile:
        if self.type == FileType.IMAGE:
            image_config = self.file_config.get('image')

            return ImagePromptMessageFile(
                data=self.data,
                detail=ImagePromptMessageFile.DETAIL.HIGH
                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
            )

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        if self.type == FileType.IMAGE:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                upload_file = (db.session.query(UploadFile)
                               .filter(
                    UploadFile.id == self.upload_file_id,
                    UploadFile.tenant_id == self.tenant_id
                ).first())

                return UploadFileParser.get_image_data(
                    upload_file=upload_file,
                    force_url=force_url
                )

        return None
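Note: a minimal usage sketch of FileObj, not part of this commit; the tenant id, URL, and config subset below are placeholders. For REMOTE_URL images, data is just the URL; for LOCAL_FILE it resolves the UploadFile row through UploadFileParser.

from core.file.file_obj import FileObj, FileType, FileTransferMethod

file_obj = FileObj(
    id=None,
    tenant_id='tenant-id',                       # placeholder
    type=FileType.IMAGE,
    transfer_method=FileTransferMethod.REMOTE_URL,
    url='https://example.com/cat.png',           # placeholder
    upload_file_id=None,
    file_config={'image': {'detail': 'high'}}    # subset of file_upload_dict
)

prompt_file = file_obj.prompt_message_file       # ImagePromptMessageFile with DETAIL.HIGH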
api/core/file/message_file_parser.py (new file, 180 lines)
@@ -0,0 +1,180 @@
from typing import List, Union, Optional, Dict

import requests

from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.file.upload_file_parser import SUPPORT_EXTENSIONS
from extensions.ext_database import db
from models.account import Account
from models.model import MessageFile, EndUser, AppModelConfig, UploadFile


class MessageFileParser:

    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
                                         user: Union[Account, EndUser]) -> List[FileObj]:
        """
        validate and transform files arg

        :param files:
        :param app_model_config:
        :param user:
        :return:
        """
        file_upload_config = app_model_config.file_upload_dict

        for file in files:
            if not isinstance(file, dict):
                raise ValueError('Invalid file format, must be dict')
            if not file.get('type'):
                raise ValueError('Missing file type')
            FileType.value_of(file.get('type'))
            if not file.get('transfer_method'):
                raise ValueError('Missing file transfer method')
            FileTransferMethod.value_of(file.get('transfer_method'))
            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
                if not file.get('url'):
                    raise ValueError('Missing file url')
                if not file.get('url').startswith('http'):
                    raise ValueError('Invalid file url')
            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
                raise ValueError('Missing file upload_file_id')

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                # parse and validate files
                image_config = file_upload_config.get('image')

                # check if image file feature is enabled
                if not image_config['enabled']:
                    continue

                # Validate number of files
                if len(files) > image_config['number_limits']:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method
                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f'Invalid file type: {file_obj.type}')

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check remote url valid and is image
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # get upload file from upload_file_id
                        upload_file = (db.session.query(UploadFile)
                                       .filter(
                            UploadFile.id == file_obj.upload_file_id,
                            UploadFile.tenant_id == self.tenant_id,
                            UploadFile.created_by == user.id,
                            UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
                            UploadFile.extension.in_(SUPPORT_EXTENSIONS)
                        ).first())

                        # check upload file is belong to tenant and user
                        if not upload_file:
                            raise ValueError('Invalid upload file')

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
        """
        transform message files

        :param files:
        :param app_model_config:
        :return:
        """
        # transform files to file objs
        type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
                      file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
        """
        transform files to file objs

        :param files:
        :param file_upload_config:
        :return:
        """
        type_file_objs: Dict[FileType, List[FileObj]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            file_obj = self._to_file_obj(file, file_upload_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
        """
        transform file to file obj

        :param file:
        :return:
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
            return FileObj(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get('type')),
                transfer_method=transfer_method,
                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                file_config=file_upload_config
            )
        else:
            return FileObj(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                upload_file_id=file.upload_file_id or None,
                file_config=file_upload_config
            )

    def _check_image_remote_url(self, url):
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            response = requests.head(url, headers=headers, allow_redirects=True)
            if response.status_code == 200:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"
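Note: a sketch of the files argument this parser accepts, based on the validation above; the ids and URL are placeholders, and app_model_config and current_user are assumed to be in scope.

parser = MessageFileParser(tenant_id='tenant-id', app_id='app-id')   # placeholder ids
file_objs = parser.validate_and_transform_files_arg(
    files=[
        {'type': 'image', 'transfer_method': 'remote_url',
         'url': 'https://example.com/photo.jpg'},
        {'type': 'image', 'transfer_method': 'local_file',
         'upload_file_id': 'upload-file-uuid'},   # placeholder
    ],
    app_model_config=app_model_config,
    user=current_user,
)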
api/core/file/upload_file_parser.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional

from flask import current_app

from extensions.ext_storage import storage

SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']


class UploadFileParser:
    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        if not upload_file:
            return None

        if upload_file.extension not in SUPPORT_EXTENSIONS:
            return None

        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
            return cls.get_signed_temp_image_url(upload_file)
        else:
            # get image file base64
            try:
                data = storage.load(upload_file.key)
            except FileNotFoundError:
                logging.error(f'File not found: {upload_file.key}')
                return None

            encoded_string = base64.b64encode(data).decode('utf-8')
            return f'data:{upload_file.mime_type};base64,{encoded_string}'

    @classmethod
    def get_signed_temp_image_url(cls, upload_file) -> str:
        """
        get signed url from upload file

        :param upload_file: UploadFile object
        :return:
        """
        base_url = current_app.config.get('FILES_URL')
        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature

        :param upload_file_id: file id
        :param timestamp: timestamp
        :param nonce: nonce
        :param sign: signature
        :return:
        """
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # verify signature
        if sign != recalculated_encoded_sign:
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= 300  # expired after 5 minutes
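Note: the sign/verify pair above is plain HMAC-SHA256 over a pipe-delimited string. A self-contained sketch of the same scheme outside Flask (the secret and file id are made up):

import base64, hashlib, hmac, os, time

secret_key = b'example-secret'          # stands in for SECRET_KEY
upload_file_id = 'file-id'              # placeholder
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
digest = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = base64.urlsafe_b64encode(digest).decode()

# verification recomputes the digest and enforces the 5-minute window
recalculated = base64.urlsafe_b64encode(
    hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
).decode()
assert sign == recalculated and int(time.time()) - int(timestamp) <= 300

A constant-time comparison (hmac.compare_digest) would be a safer choice than != in verify_image_file_signature.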
@@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT

class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
    def generate_conversation_name(cls, tenant_id: str, query):
        prompt = CONVERSATION_TITLE_PROMPT

        if len(query) > 2000:

@@ -40,8 +40,12 @@ class LLMGenerator:

        result_dict = json.loads(answer)
        answer = result_dict['Your Output']
        name = answer.strip()

        return answer.strip()
        if len(name) > 75:
            name = name[:75] + '...'

        return name

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
@@ -3,6 +3,7 @@ from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage

from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db

@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
    @property
    def buffer(self) -> List[BaseMessage]:
        """String buffer of memory."""
        app_model = self.conversation.app

        # fetch limited messages desc, and return reversed
        messages = db.session.query(Message).filter(
            Message.conversation_id == self.conversation.id,

@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
        ).order_by(Message.created_at.desc()).limit(self.message_limit).all()

        messages = list(reversed(messages))
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)

        chat_messages: List[PromptMessage] = []
        for message in messages:
            chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
            files = message.message_files
            if files:
                file_objs = message_file_parser.transform_message_files(
                    files, message.app_model_config
                )

                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
                chat_messages.append(PromptMessage(
                    content=message.query,
                    type=MessageType.USER,
                    files=prompt_message_files
                ))
            else:
                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))

            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))

        if not chat_messages:
@@ -1,4 +1,5 @@
import enum
from typing import Any, cast, Union, List, Dict

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel

@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
    SYSTEM = 'system'


class PromptMessageFileType(enum.Enum):
    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        for member in PromptMessageFileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class PromptMessageFile(BaseModel):
    type: PromptMessageFileType
    data: Any


class ImagePromptMessageFile(PromptMessageFile):
    class DETAIL(enum.Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    detail: DETAIL = DETAIL.LOW


class PromptMessage(BaseModel):
    type: MessageType = MessageType.USER
    content: str = ''
    files: list[PromptMessageFile] = []
    function_call: dict = None


class LCHumanMessageWithFiles(HumanMessage):
    # content: Union[str, List[Union[str, Dict]]]
    content: str
    files: list[PromptMessageFile]


def to_lc_messages(messages: list[PromptMessage]):
    lc_messages = []
    for message in messages:
        if message.type == MessageType.USER:
            lc_messages.append(HumanMessage(content=message.content))
            if not message.files:
                lc_messages.append(HumanMessage(content=message.content))
            else:
                lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
        elif message.type == MessageType.ASSISTANT:
            additional_kwargs = {}
            if message.function_call:

@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
    prompt_messages = []
    for message in messages:
        if isinstance(message, HumanMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
            if isinstance(message, LCHumanMessageWithFiles):
                prompt_messages.append(PromptMessage(
                    content=message.content,
                    type=MessageType.USER,
                    files=message.files
                ))
            else:
                prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
        elif isinstance(message, AIMessage):
            message_kwargs = {
                'content': message.content,
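Note: a quick sketch of the new conversion path (the data URI is a placeholder). A PromptMessage carrying files becomes an LCHumanMessageWithFiles instead of a plain HumanMessage:

from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, ImagePromptMessageFile, to_lc_messages
)

msg = PromptMessage(
    content='What is in this picture?',
    type=MessageType.USER,
    files=[ImagePromptMessageFile(
        data='data:image/png;base64,<placeholder>',
        detail=ImagePromptMessageFile.DETAIL.LOW,
    )],
)

lc_messages = to_lc_messages([msg])   # -> [LCHumanMessageWithFiles(...)]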
@@ -1,11 +1,9 @@
import decimal
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage

from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel

@@ -16,32 +16,59 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from models.model import AppModelConfig


class AppMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'


class PromptTransform:
    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
    def get_prompt(self,
                   app_mode: str,
                   app_model_config: AppModelConfig,
                   pre_prompt: str,
                   inputs: dict,
                   query: str,
                   files: List[PromptMessageFile],
                   context: Optional[str],
                   memory: Optional[BaseChatMemory],
                   model_instance: BaseLLM) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
        return [PromptMessage(content=prompt)], stops
        model_mode = app_model_config.model_dict['mode']

        app_mode_enum = AppMode(app_mode)
        model_mode_enum = ModelMode(model_mode)

        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))

        if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
            stops = None

            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
                                                                                   query, context, memory,
                                                                                   model_instance, files)
        else:
            stops = prompt_rules.get('stops')
            if stops is not None and len(stops) == 0:
                stops = None

            prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
                                                                      memory,
                                                                      model_instance, files)
        return prompt_messages, stops

    def get_advanced_prompt(self,
                            app_mode: str,
                            app_model_config: AppModelConfig,
                            inputs: dict,
                            query: str,
                            files: List[PromptMessageFile],
                            context: Optional[str],
                            memory: Optional[BaseChatMemory],
                            model_instance: BaseLLM) -> List[PromptMessage]:

    def get_advanced_prompt(self,
                            app_mode: str,
                            app_model_config: str,
                            inputs: dict,
                            query: str,
                            context: Optional[str],
                            memory: Optional[BaseChatMemory],
                            model_instance: BaseLLM) -> List[PromptMessage]:

        model_mode = app_model_config.model_dict['mode']

        app_mode_enum = AppMode(app_mode)

@@ -51,15 +78,20 @@ class PromptTransform:

        if app_mode_enum == AppMode.CHAT:
            if model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
                                                                                      files, context, memory,
                                                                                      model_instance)
            elif model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
                                                                                context, memory, model_instance)
        elif app_mode_enum == AppMode.COMPLETION:
            if model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
                                                                                      files, context)
            elif model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
                                                                                            files, context)

        return prompt_messages

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,

@@ -71,7 +103,7 @@ class PromptTransform:
        return external_context[memory_key]

    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
                                               max_token_limit: int) -> List[PromptMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory.return_messages = True

@@ -79,7 +111,7 @@ class PromptTransform:
        external_context = memory.load_memory_variables({})
        memory.return_messages = False
        return to_prompt_messages(external_context[memory_key])

    def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
        # baichuan
        if isinstance(model_instance, BaichuanModel):

@@ -94,13 +126,13 @@ class PromptTransform:
            return 'common_completion'
        else:
            return 'common_chat'

    def _prompt_file_name_for_baichuan(self, mode: str) -> str:
        if mode == 'completion':
            return 'baichuan_completion'
        else:
            return 'baichuan_chat'

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(

@@ -111,12 +143,53 @@ class PromptTransform:
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)

    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory],
                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:

    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                                        query: str,
                                                        context: Optional[str],
                                                        memory: Optional[BaseChatMemory],
                                                        model_instance: BaseLLM,
                                                        files: List[PromptMessageFile]) -> List[PromptMessage]:
        prompt_messages = []

        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages

    def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                           query: str,
                                           context: Optional[str],
                                           memory: Optional[BaseChatMemory],
                                           model_instance: BaseLLM,
                                           files: List[PromptMessageFile]) -> List[PromptMessage]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])

@@ -175,16 +248,12 @@ class PromptTransform:

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None
        return [PromptMessage(content=prompt, files=files)]

        return prompt, stops

    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
        if '#context#' in prompt_template.variable_keys:
            if context:
                prompt_inputs['#context#'] = context
            else:
                prompt_inputs['#context#'] = ''

@@ -195,17 +264,18 @@ class PromptTransform:
        else:
            prompt_inputs['#query#'] = ''

    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
                                prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
                                prompt_template: PromptTemplateParser, prompt_inputs: dict,
                                model_instance: BaseLLM) -> None:
        if '#histories#' in prompt_template.variable_keys:
            if memory:
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=raw_prompt,
                    inputs={ '#histories#': '', **prompt_inputs }
                    inputs={'#histories#': '', **prompt_inputs}
                )

                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

                memory.human_prefix = conversation_histories_role['user_prefix']
                memory.ai_prefix = conversation_histories_role['assistant_prefix']
                histories = self._get_history_messages_from_memory(memory, rest_tokens)

@@ -213,7 +283,8 @@ class PromptTransform:
        else:
            prompt_inputs['#histories#'] = ''

    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
                               model_instance: BaseLLM) -> None:
        if memory:
            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)

@@ -242,19 +313,19 @@ class PromptTransform:
        return prompt

    def _get_chat_app_completion_model_prompt_messages(self,
                                                       app_model_config: str,
                                                       inputs: dict,
                                                       query: str,
                                                       context: Optional[str],
                                                       memory: Optional[BaseChatMemory],
                                                       model_instance: BaseLLM) -> List[PromptMessage]:
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       query: str,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str],
                                                       memory: Optional[BaseChatMemory],
                                                       model_instance: BaseLLM) -> List[PromptMessage]:

        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
        conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

        prompt_messages = []
        prompt = ''

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -262,28 +333,29 @@ class PromptTransform:

        self._set_query_variable(query, prompt_template, prompt_inputs)

        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
                                     model_instance)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

        return prompt_messages

    def _get_chat_app_chat_model_prompt_messages(self,
                                                 app_model_config: str,
                                                 inputs: dict,
                                                 query: str,
                                                 context: Optional[str],
                                                 memory: Optional[BaseChatMemory],
                                                 model_instance: BaseLLM) -> List[PromptMessage]:
                                                 app_model_config: AppModelConfig,
                                                 inputs: dict,
                                                 query: str,
                                                 files: List[PromptMessageFile],
                                                 context: Optional[str],
                                                 memory: Optional[BaseChatMemory],
                                                 model_instance: BaseLLM) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']
            prompt = ''

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -292,23 +364,23 @@ class PromptTransform:

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages

    def _get_completion_app_completion_model_prompt_messages(self,
                                                             app_model_config: str,
                                                             inputs: dict,
                                                             context: Optional[str]) -> List[PromptMessage]:
                                                             app_model_config: AppModelConfig,
                                                             inputs: dict,
                                                             files: List[PromptMessageFile],
                                                             context: Optional[str]) -> List[PromptMessage]:
        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

        prompt_messages = []
        prompt = ''

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -316,21 +388,21 @@ class PromptTransform:

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
        prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))

        return prompt_messages

    def _get_completion_app_chat_model_prompt_messages(self,
                                                       app_model_config: str,
                                                       inputs: dict,
                                                       context: Optional[str]) -> List[PromptMessage]:
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str]) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']
            prompt = ''

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -339,6 +411,11 @@ class PromptTransform:

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

        return prompt_messages
            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        for prompt_message in prompt_messages[::-1]:
            if prompt_message.type == MessageType.USER:
                prompt_message.files = files
                break

        return prompt_messages
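Note: for orientation, a sketch of the simple-mode call as it now looks from the completion path; values are illustrative, and app_model_config, model_instance, and file_obj are assumed to be in scope.

prompt_transform = PromptTransform()
prompt_messages, stops = prompt_transform.get_prompt(
    app_mode='chat',
    app_model_config=app_model_config,
    pre_prompt=app_model_config.pre_prompt,
    inputs={'topic': 'cats'},                 # illustrative inputs
    query='Describe the attached image.',
    files=[file_obj.prompt_message_file],     # List[PromptMessageFile]
    context=None,
    memory=None,
    model_instance=model_instance,
)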
api/core/third_party/langchain/llms/chat_open_ai.py (vendored, 104 lines)
@@ -1,10 +1,13 @@
import os

from typing import Dict, Any, Optional, Union, Tuple
from typing import Dict, Any, Optional, Union, Tuple, List, cast

from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from pydantic import root_validator

from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile


class EnhanceChatOpenAI(ChatOpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)

@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        model, encoding = self._get_encoding_model()
        if model.startswith("gpt-3.5-turbo-0301"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        messages_dict = [self._convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                #  which need to download the image and then get the resolution for calculation,
                #  and will increase the request delay
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']

                    value = text
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, LCHumanMessageWithFiles):
            content = [
                {
                    "type": "text",
                    "text": message.content
                }
            ]

            for file in message.files:
                if file.type == PromptMessageFileType.IMAGE:
                    file = cast(ImagePromptMessageFile, file)
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": file.data,
                            "detail": file.detail.value
                        }
                    })

            message_dict = {"role": "user", "content": content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
            if "function_call" in message.additional_kwargs:
                message_dict["function_call"] = message.additional_kwargs["function_call"]
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.name,
            }
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict
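Note: for a user message carrying one image, _convert_message_to_dict produces the OpenAI vision content-array form, roughly as below (the text and URL are placeholders):

{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {
            "url": "data:image/png;base64,<placeholder>",   # or a signed preview URL
            "detail": "low"
        }}
    ]
}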
@@ -1,5 +1,3 @@
import logging

from core.generator.llm_generator import LLMGenerator
from events.message_event import message_was_created
from extensions.ext_database import db

@@ -10,8 +8,9 @@ def handle(sender, **kwargs):
    message = sender
    conversation = kwargs.get('conversation')
    is_first_message = kwargs.get('is_first_message')
    auto_generate_name = kwargs.get('auto_generate_name', True)

    if is_first_message:
    if auto_generate_name and is_first_message:
        if conversation.mode == 'chat':
            app_model = conversation.app
            if not app_model:

@@ -19,14 +18,9 @@ def handle(sender, **kwargs):

            # generate conversation name
            try:
                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)

                if len(name) > 75:
                    name = name[:75] + '...'

                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
                conversation.name = name
            except:
                conversation.name = 'New conversation'
                pass

            db.session.add(conversation)
            db.session.commit()
@@ -1,6 +1,7 @@
import os
import shutil
from contextlib import closing
from typing import Union, Generator

import boto3
from botocore.exceptions import ClientError

@@ -45,7 +46,13 @@ class Storage:
        with open(os.path.join(os.getcwd(), filename), "wb") as f:
            f.write(data)

    def load(self, filename):
    def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
        if stream:
            return self.load_stream(filename)
        else:
            return self.load_once(filename)

    def load_once(self, filename: str) -> bytes:
        if self.storage_type == 's3':
            try:
                with closing(self.client) as client:

@@ -69,6 +76,34 @@ class Storage:

        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            if self.storage_type == 's3':
                try:
                    with closing(self.client) as client:
                        response = client.get_object(Bucket=self.bucket_name, Key=filename)
                        for chunk in response['Body'].iter_chunks():
                            yield chunk
                except ClientError as ex:
                    if ex.response['Error']['Code'] == 'NoSuchKey':
                        raise FileNotFoundError("File not found")
                    else:
                        raise
            else:
                if not self.folder or self.folder.endswith('/'):
                    filename = self.folder + filename
                else:
                    filename = self.folder + '/' + filename

                if not os.path.exists(filename):
                    raise FileNotFoundError("File not found")

                with open(filename, "rb") as f:
                    while chunk := f.read(4096):  # Read in chunks of 4KB
                        yield chunk

        return generate()

    def download(self, filename, target_filepath):
        if self.storage_type == 's3':
            with closing(self.client) as client:
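Note: a usage sketch for the new streaming path, e.g. when serving a file preview without buffering the whole object in memory; the key and consumer below are placeholders. Chunks arrive as bytes objects (4 KB reads for local storage, S3 body chunks otherwise).

for chunk in storage.load('upload_files/tenant-id/file-id.png', stream=True):
    response_stream.write(chunk)   # hypothetical consumer, e.g. a Flask response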
@@ -32,7 +32,8 @@ model_config_fields = {
    'prompt_type': fields.String,
    'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
    'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
    'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
    'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
    'file_upload': fields.Raw(attribute='file_upload_dict'),
}

app_detail_fields = {

@@ -140,4 +141,4 @@ app_site_fields = {
    'privacy_policy': fields.String,
    'customize_token_strategy': fields.String,
    'prompt_public': fields.Boolean
}
@@ -28,6 +28,12 @@ annotation_fields = {
    'created_at': TimestampField
}

message_file_fields = {
    'id': fields.String,
    'type': fields.String,
    'url': fields.String,
}

message_detail_fields = {
    'id': fields.String,
    'conversation_id': fields.String,

@@ -43,7 +49,8 @@ message_detail_fields = {
    'from_account_id': fields.String,
    'feedbacks': fields.List(fields.Nested(feedback_fields)),
    'annotation': fields.Nested(annotation_fields, allow_null=True),
    'created_at': TimestampField
    'created_at': TimestampField,
    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}

feedback_stat_fields = {

@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
    'message': fields.Nested(message_detail_fields, attribute='first_message'),
}

simple_model_config_fields = {
    'model': fields.Raw(attribute='model_dict'),
    'pre_prompt': fields.String,
}

conversation_with_summary_fields = {
    'id': fields.String,
    'status': fields.String,

@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}
@@ -4,7 +4,8 @@ from libs.helper import TimestampField

upload_config_fields = {
    'file_size_limit': fields.Integer,
    'batch_count_limit': fields.Integer
    'batch_count_limit': fields.Integer,
    'image_file_size_limit': fields.Integer,
}

file_fields = {
@@ -1,6 +1,7 @@
from flask_restful import fields

from libs.helper import TimestampField
from fields.conversation_fields import message_file_fields

feedback_fields = {
    'rating': fields.String

@@ -31,6 +32,7 @@ message_fields = {
    'inputs': fields.Raw,
    'query': fields.String,
    'answer': fields.String,
    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
    'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
    'created_at': TimestampField
api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""add gpt4v supports

Revision ID: 8fe468ba0ca5
Revises: a9836e3baeee
Create Date: 2023-11-09 11:39:00.006432

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '8fe468ba0ca5'
down_revision = 'a9836e3baeee'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('message_files',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('message_id', postgresql.UUID(), nullable=False),
        sa.Column('type', sa.String(length=255), nullable=False),
        sa.Column('transfer_method', sa.String(length=255), nullable=False),
        sa.Column('url', sa.Text(), nullable=True),
        sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
        sa.Column('created_by_role', sa.String(length=255), nullable=False),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='message_file_pkey')
    )
    with op.batch_alter_table('message_files', schema=None) as batch_op:
        batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
        batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))

    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.drop_column('created_by_role')

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('file_upload')

    with op.batch_alter_table('message_files', schema=None) as batch_op:
        batch_op.drop_index('message_file_message_idx')
        batch_op.drop_index('message_file_created_by_idx')

    op.drop_table('message_files')
    # ### end Alembic commands ###
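Note: assuming the repo's usual Flask-Migrate workflow, this revision (chained on a9836e3baeee) is applied from the api directory with:

flask db upgrade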
@@ -1,10 +1,10 @@
import json
from json import JSONDecodeError

from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID

from core.file.upload_file_parser import UploadFileParser
from libs.helper import generate_string
from extensions.ext_database import db
from .account import Account, Tenant

@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
    completion_prompt_config = db.Column(db.Text)
    dataset_configs = db.Column(db.Text)
    external_data_tools = db.Column(db.Text)
    file_upload = db.Column(db.Text)

    @property
    def app(self):

@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
    def dataset_configs_dict(self) -> dict:
        return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}

    @property
    def file_upload_dict(self) -> dict:
        return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}

    def to_dict(self) -> dict:
        return {
            "provider": "",

@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
            "prompt_type": self.prompt_type,
            "chat_prompt_config": self.chat_prompt_config_dict,
            "completion_prompt_config": self.completion_prompt_config_dict,
            "dataset_configs": self.dataset_configs_dict
            "dataset_configs": self.dataset_configs_dict,
            "file_upload": self.file_upload_dict
        }

    def from_model_config_dict(self, model_config: dict):

@@ -213,6 +219,8 @@ class AppModelConfig(db.Model):
            if model_config.get('completion_prompt_config') else None
        self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
            if model_config.get('dataset_configs') else None
        self.file_upload = json.dumps(model_config.get('file_upload')) \
            if model_config.get('file_upload') else None
        return self

    def copy(self):

@@ -238,7 +246,8 @@ class AppModelConfig(db.Model):
            prompt_type=self.prompt_type,
            chat_prompt_config=self.chat_prompt_config,
            completion_prompt_config=self.completion_prompt_config,
            dataset_configs=self.dataset_configs
            dataset_configs=self.dataset_configs,
            file_upload=self.file_upload
        )

        return new_app_model_config

@@ -512,6 +521,37 @@ class Message(db.Model):
        return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
            .order_by(DatasetRetrieverResource.position.asc()).all()

    @property
    def message_files(self):
        return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()

    @property
    def files(self):
        message_files = self.message_files

        files = []
        for message_file in message_files:
            url = message_file.url
            if message_file.type == 'image':
                if message_file.transfer_method == 'local_file':
                    upload_file = (db.session.query(UploadFile)
                                   .filter(
                        UploadFile.id == message_file.upload_file_id
                    ).first())

                    url = UploadFileParser.get_image_data(
                        upload_file=upload_file,
                        force_url=True
                    )

            files.append({
                'id': message_file.id,
                'type': message_file.type,
                'url': url
            })

        return files


class MessageFeedback(db.Model):
    __tablename__ = 'message_feedbacks'

@@ -540,6 +580,25 @@ class MessageFeedback(db.Model):
        return account


class MessageFile(db.Model):
    __tablename__ = 'message_files'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='message_file_pkey'),
        db.Index('message_file_message_idx', 'message_id'),
        db.Index('message_file_created_by_idx', 'created_by')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    transfer_method = db.Column(db.String(255), nullable=False)
    url = db.Column(db.Text, nullable=True)
    upload_file_id = db.Column(UUID, nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


class MessageAnnotation(db.Model):
    __tablename__ = 'message_annotations'
    __table_args__ = (

@@ -683,6 +742,7 @@ class UploadFile(db.Model):
    size = db.Column(db.Integer, nullable=False)
    extension = db.Column(db.String(255), nullable=False)
    mime_type = db.Column(db.String(255), nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

@@ -783,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
    retriever_from = db.Column(db.Text, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -315,6 +315,9 @@ class AppModelConfigService:
        # moderation validation
        cls.is_moderation_valid(tenant_id, config)

        # file upload validation
        cls.is_file_upload_valid(config)

        # Filter out extra parameters
        filtered_config = {
            "opening_statement": config["opening_statement"],

@@ -338,7 +341,8 @@ class AppModelConfigService:
            "prompt_type": config["prompt_type"],
            "chat_prompt_config": config["chat_prompt_config"],
            "completion_prompt_config": config["completion_prompt_config"],
            "dataset_configs": config["dataset_configs"]
            "dataset_configs": config["dataset_configs"],
            "file_upload": config["file_upload"]
        }

        return filtered_config

@@ -371,6 +375,34 @@ class AppModelConfigService:
            config=config
        )

    @classmethod
    def is_file_upload_valid(cls, config: dict):
        if 'file_upload' not in config or not config["file_upload"]:
            config["file_upload"] = {}

        if not isinstance(config["file_upload"], dict):
            raise ValueError("file_upload must be of dict type")

        # check image config
        if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
            config["file_upload"]["image"] = {"enabled": False}

        if config['file_upload']['image']['enabled']:
            number_limits = config['file_upload']['image']['number_limits']
            if number_limits < 1 or number_limits > 6:
                raise ValueError("number_limits must be in [1, 6]")

            detail = config['file_upload']['image']['detail']
            if detail not in ['high', 'low']:
                raise ValueError("detail must be in ['high', 'low']")

            transfer_methods = config['file_upload']['image']['transfer_methods']
            if not isinstance(transfer_methods, list):
                raise ValueError("transfer_methods must be of list type")
            for method in transfer_methods:
                if method not in ['remote_url', 'local_file']:
                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")

    @classmethod
    def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
        if 'external_data_tools' not in config or not config["external_data_tools"]:
|
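To make the validator concrete, a `file_upload` section that passes `is_file_upload_valid` looks like this (illustrative values; the other required config keys are omitted):

    config = {
        "file_upload": {
            "image": {
                "enabled": True,
                "number_limits": 3,  # must fall within [1, 6]
                "detail": "high",    # 'high' or 'low'
                "transfer_methods": ["remote_url", "local_file"]
            }
        }
    }
    AppModelConfigService.is_file_upload_valid(config)  # raises ValueError on invalid input
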
@@ -3,7 +3,7 @@ import logging
 import threading
 import time
 import uuid
-from typing import Generator, Union, Any, Optional
+from typing import Generator, Union, Any, Optional, List
 
 from flask import current_app, Flask
 from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
 from core.completion import Completion
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
+from core.file.message_file_parser import MessageFileParser
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.models.entity.message import PromptMessageFile
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
         # is streaming mode
         inputs = args['inputs']
         query = args['query']
+        files = args['files'] if 'files' in args and args['files'] else []
+        auto_generate_name = args['auto_generate_name'] \
+            if 'auto_generate_name' in args else True
 
         if app_model.mode != 'completion' and not query:
             raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
         # clean input by app_model_config form rules
         inputs = cls.get_cleaned_inputs(inputs, app_model_config)
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.validate_and_transform_files_arg(
+            files,
+            app_model_config,
+            user
+        )
+
         generate_task_id = str(uuid.uuid4())
 
         pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'query': query,
             'inputs': inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_conversation': conversation,
             'streaming': streaming,
             'is_model_config_override': is_model_config_override,
-            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
+            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
+            'auto_generate_name': auto_generate_name
         })
 
         generate_worker_thread.start()
 
         # wait for 10 minutes to close the thread
-        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
+                                generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -172,10 +188,12 @@ class CompletionService:
         return user
 
     @classmethod
-    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
-                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
+    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
+                        app_model_config: AppModelConfig,
+                        query: str, inputs: dict, files: List[PromptMessageFile],
+                        detached_user: Union[Account, EndUser],
                         detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
-                        retriever_from: str = 'dev'):
+                        retriever_from: str = 'dev', auto_generate_name: bool = True):
         with flask_app.app_context():
             # fixed the state of the model object when it detached from the original session
             user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
                 query=query,
                 inputs=inputs,
                 user=user,
+                files=files,
                 conversation=conversation,
                 streaming=streaming,
                 is_override=is_model_config_override,
-                retriever_from=retriever_from
+                retriever_from=retriever_from,
+                auto_generate_name=auto_generate_name
             )
         except (ConversationTaskInterruptException, ConversationTaskStoppedException):
             pass
@@ -215,7 +235,8 @@ class CompletionService:
             db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
+                            generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
@@ -274,6 +295,12 @@ class CompletionService:
         model_dict['completion_params'] = completion_params
         app_model_config.model = json.dumps(model_dict)
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.transform_message_files(
+            message.files, app_model_config
+        )
+
         generate_task_id = str(uuid.uuid4())
 
         pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'query': message.query,
             'inputs': message.inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_conversation': None,
             'streaming': streaming,
             'is_model_config_override': True,
-            'retriever_from': retriever_from
+            'retriever_from': retriever_from,
+            'auto_generate_name': False
         })
 
         generate_worker_thread.start()
 
@@ -388,7 +417,8 @@ class CompletionService:
             if event == 'message':
                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
             elif event == 'message_replace':
-                yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
+                yield "data: " + json.dumps(
+                    cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
             elif event == 'chain':
                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
             elif event == 'agent_thought':
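A sketch of the request arguments a caller now passes (the file-dict keys are inferred from the `MessageFile` columns above; the final call is indicative only, since the entry-point signature is not shown in these hunks):

    args = {
        'inputs': {},
        'query': 'What is in this image?',
        'files': [{
            'type': 'image',
            'transfer_method': 'remote_url',
            'url': 'https://example.com/sample.png'
        }],
        'auto_generate_name': True,  # defaults to True when omitted
    }
    # result = CompletionService.completion(app_model, user, args, streaming=True)
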
@@ -1,17 +1,20 @@
 from typing import Union, Optional
 
+from core.generator.llm_generator import LLMGenerator
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from models.account import Account
-from models.model import Conversation, App, EndUser
+from models.model import Conversation, App, EndUser, Message
 from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 
 class ConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
                               last_id: Optional[str], limit: int,
-                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
+                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
@@ -29,6 +32,9 @@ class ConversationService:
         if exclude_ids is not None:
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
 
+        if exclude_debug_conversation:
+            base_query = base_query.filter(Conversation.override_model_configs == None)
+
         if last_id:
             last_conversation = base_query.filter(
                 Conversation.id == last_id,
@@ -63,10 +69,36 @@ class ConversationService:
 
     @classmethod
     def rename(cls, app_model: App, conversation_id: str,
-               user: Optional[Union[Account | EndUser]], name: str):
+               user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
         conversation = cls.get_conversation(app_model, conversation_id, user)
 
-        conversation.name = name
+        if auto_generate:
+            return cls.auto_generate_name(app_model, conversation)
+        else:
+            conversation.name = name
+            db.session.commit()
+
+            return conversation
+
+    @classmethod
+    def auto_generate_name(cls, app_model: App, conversation: Conversation):
+        # get conversation first message
+        message = db.session.query(Message) \
+            .filter(
+                Message.app_id == app_model.id,
+                Message.conversation_id == conversation.id
+            ).order_by(Message.created_at.asc()).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        # generate conversation name
+        try:
+            name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
+            conversation.name = name
+        except:
+            pass
+
         db.session.commit()
 
         return conversation
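Usage sketch for the reworked rename (identifiers are placeholders): with `auto_generate=True` the `name` argument is ignored and `LLMGenerator` derives a title from the conversation's first message:

    # explicit rename
    ConversationService.rename(app_model, conversation_id, user, name='Trip planning', auto_generate=False)

    # LLM-generated name; generation failures are swallowed and the old name kept
    ConversationService.rename(app_model, conversation_id, user, name='', auto_generate=True)
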
@@ -1,46 +1,62 @@
 import datetime
 import hashlib
-import time
 import uuid
+from typing import Generator, Tuple, Union
 
-from cachetools import TTLCache
-from flask import request, current_app
+from flask import current_app
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
+from werkzeug.exceptions import NotFound
 
 from core.data_loader.file_extractor import FileExtractor
+from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_storage import storage
 from extensions.ext_database import db
-from models.model import UploadFile
+from models.account import Account
+from models.model import UploadFile, EndUser
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
+                      'jpg', 'jpeg', 'png', 'webp', 'gif']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
 PREVIEW_WORDS_LIMIT = 3000
-cache = TTLCache(maxsize=None, ttl=30)
 
 
 class FileService:
 
     @staticmethod
-    def upload_file(file: FileStorage) -> UploadFile:
+    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
+        extension = file.filename.split('.')[-1]
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
         # read file content
         file_content = file.read()
 
         # get file size
         file_size = len(file_content)
 
-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+        if extension.lower() in IMAGE_EXTENSIONS:
+            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
+        else:
+            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
 
         if file_size > file_size_limit:
             message = f'File size exceeded. {file_size} > {file_size_limit}'
             raise FileTooLargeError(message)
 
-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
         # user uuid as file name
         file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
+
+        if isinstance(user, Account):
+            current_tenant_id = user.current_tenant_id
+        else:
+            # end_user
+            current_tenant_id = user.tenant_id
+
+        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
 
         # save file to storage
         storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
         # save file to db
         config = current_app.config
         upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             storage_type=config['STORAGE_TYPE'],
             key=file_key,
             name=file.filename,
             size=file_size,
             extension=extension,
             mime_type=file.mimetype,
-            created_by=current_user.id,
+            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
+            created_by=user.id,
             created_at=datetime.datetime.utcnow(),
             used=False,
             hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +116,6 @@ class FileService:
 
     @staticmethod
     def get_file_preview(file_id: str) -> str:
-        # get file storage key
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()
@@ -121,3 +132,25 @@ class FileService:
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
 
         return text
+
+    @staticmethod
+    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
+        result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
+        if not result:
+            raise NotFound("File not found or signature is invalid")
+
+        upload_file = db.session.query(UploadFile) \
+            .filter(UploadFile.id == file_id) \
+            .first()
+
+        if not upload_file:
+            raise NotFound("File not found or signature is invalid")
+
+        # extract text from file
+        extension = upload_file.extension
+        if extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        generator = storage.load(upload_file.key, stream=True)
+
+        return generator, upload_file.mime_type
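A usage sketch for the reworked upload path (variables are placeholders): uploads now work for end users as well as console accounts, and image-only call sites can enforce `only_image`:

    # Enforces IMAGE_EXTENSIONS and UPLOAD_IMAGE_FILE_SIZE_LIMIT instead of the generic limit.
    upload_file = FileService.upload_file(file=file_storage, user=end_user, only_image=True)

    # Stream it back through the signed preview endpoint; timestamp, nonce and
    # sign come from the signed URL issued via UploadFileParser.
    generator, mime_type = FileService.get_image_preview(
        file_id=upload_file.id, timestamp=timestamp, nonce=nonce, sign=sign)
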
@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
 class WebConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
-                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
+                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         include_ids = None
         exclude_ids = None
         if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
             last_id=last_id,
             limit=limit,
             include_ids=include_ids,
-            exclude_ids=exclude_ids
+            exclude_ids=exclude_ids,
+            exclude_debug_conversation=exclude_debug_conversation
         )
 
     @classmethod
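Sketch of a web-side listing that hides debug conversations, i.e. those with `override_model_configs` set (placeholder variables):

    pagination = WebConversationService.pagination_by_last_id(
        app_model=app_model,
        user=end_user,
        last_id=None,
        limit=20,
        exclude_debug_conversation=True,
    )
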
@@ -5,7 +5,7 @@ from unittest.mock import patch
 from langchain.schema import Generation, ChatGeneration, AIMessage
 
 from core.model_providers.providers.openai_provider import OpenAIProvider
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
 from core.model_providers.models.entity.model_params import ModelKwargs
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
     assert rst == 22
 
 
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_chat_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(
+                data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.get_num_tokens(messages)
+    assert rst == 77
+
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
         messages,
         stop=['\nHuman:'],
     )
     assert (len(rst.content) > 0)
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.run(
+        messages,
+    )
+    assert len(rst.content) > 0
@@ -19,18 +19,22 @@ services:
       # different from api or web app domain.
       # example: http://cloud.dify.ai
       CONSOLE_API_URL: ''
-      # The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
+      # The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
       # different from console domain.
       # example: http://api.dify.ai
       SERVICE_API_URL: ''
-      # The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_API_URL: ''
-      # The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_WEB_URL: ''
+      # File preview or download URL prefix.
+      # Used to display a file preview or download URL to the front-end, or as multi-modal model inputs;
+      # the URL is signed and carries an expiration time.
+      FILES_URL: ''
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.
@@ -17,6 +17,11 @@ server {
         include proxy.conf;
     }
 
+    location /files {
+        proxy_pass http://api:5001;
+        include proxy.conf;
+    }
+
     location / {
         proxy_pass http://web:3000;
         include proxy.conf;
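Taken together with the new FILES_URL setting, preview traffic now flows from the front-end through the nginx `/files` location to the API container. The URL shape below is an assumption for illustration; the concrete route lives in the new `controllers/files` blueprint, which is not shown in this diff:

    # Hypothetical signed preview URL built on FILES_URL (illustrative only):
    #   {FILES_URL}/files/{file_id}/image-preview?timestamp=...&nonce=...&sign=...
    # The signature is checked by UploadFileParser.verify_image_file_signature
    # before FileService.get_image_preview streams the file back.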