diff --git a/api/.env.example b/api/.env.example index f50d981a7d..8a451dec17 100644 --- a/api/.env.example +++ b/api/.env.example @@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001 APP_API_URL=http://127.0.0.1:5001 APP_WEB_URL=http://127.0.0.1:3000 +# Files URL +FILES_URL=http://127.0.0.1:5001 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 @@ -70,6 +73,14 @@ MILVUS_USER=root MILVUS_PASSWORD=Milvus MILVUS_SECURE=false +# Upload configuration +UPLOAD_FILE_SIZE_LIMIT=15 +UPLOAD_FILE_BATCH_LIMIT=5 +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 + +# Model Configuration +MULTIMODAL_SEND_IMAGE_FORMAT=base64 + # Mail configuration, support: resend MAIL_TYPE= MAIL_DEFAULT_SEND_FROM=no-reply diff --git a/api/app.py b/api/app.py index bc0d25224e..0a43813201 100644 --- a/api/app.py +++ b/api/app.py @@ -126,6 +126,7 @@ def register_blueprints(app): from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp CORS(service_api_bp, allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], @@ -155,6 +156,12 @@ def register_blueprints(app): app.register_blueprint(console_app_bp) + CORS(files_bp, + allow_headers=['Content-Type'], + methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] + ) + app.register_blueprint(files_bp) + # create app app = create_app() diff --git a/api/config.py b/api/config.py index efc1bacc59..0e2385fdbe 100644 --- a/api/config.py +++ b/api/config.py @@ -26,6 +26,7 @@ DEFAULTS = { 'SERVICE_API_URL': 'https://api.dify.ai', 'APP_WEB_URL': 'https://udify.app', 'APP_API_URL': 'https://udify.app', + 'FILES_URL': '', 'STORAGE_TYPE': 'local', 'STORAGE_LOCAL_PATH': 'storage', 'CHECK_UPDATE_URL': 'https://updates.dify.ai', @@ -57,7 +58,9 @@ DEFAULTS = { 'CLEAN_DAY_SETTING': 30, 'UPLOAD_FILE_SIZE_LIMIT': 15, 'UPLOAD_FILE_BATCH_LIMIT': 5, - 'OUTPUT_MODERATION_BUFFER_SIZE': 300 + 'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10, + 'OUTPUT_MODERATION_BUFFER_SIZE': 300, + 'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64' } @@ -84,15 +87,9 @@ class Config: """Application configuration class.""" def __init__(self): - # app settings - self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL') - self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL') - self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL') - self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL') - self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL') - self.CONSOLE_URL = get_env('CONSOLE_URL') - self.API_URL = get_env('API_URL') - self.APP_URL = get_env('APP_URL') + # ------------------------ + # General Configurations. + # ------------------------ self.CURRENT_VERSION = "0.3.29" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" @@ -100,70 +97,55 @@ class Config: self.TESTING = False self.LOG_LEVEL = get_env('LOG_LEVEL') + # The backend URL prefix of the console API. + # used to concatenate the login authorization callback or notion integration callback. + self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL') + + # The front-end URL prefix of the console web. + # used to concatenate some front-end addresses and for CORS configuration use. 
+ self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL') + + # WebApp API backend Url prefix. + # used to declare the back-end URL for the front-end API. + self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL') + + # WebApp Url prefix. + # used to display the WebApp Base Url to the front-end. + self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL') + + # Service API Url prefix. + # used to display Service API Base Url to the front-end. + self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL') + + # File preview or download Url prefix. + # used to display File preview or download Url to the front-end or as multimodal inputs; + # Url is signed and has expiration time. + self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL + + # Fallback Url prefix. + # Will be deprecated in the future. + self.CONSOLE_URL = get_env('CONSOLE_URL') + self.API_URL = get_env('API_URL') + self.APP_URL = get_env('APP_URL') + # Your App secret key will be used for securely signing the session cookie # Make sure you are changing this key for your deployment with a strong key. # You can generate a strong key using `openssl rand -base64 42`. # Alternatively you can set it with `SECRET_KEY` environment variable. self.SECRET_KEY = get_env('SECRET_KEY') - # redis settings - self.REDIS_HOST = get_env('REDIS_HOST') - self.REDIS_PORT = get_env('REDIS_PORT') - self.REDIS_USERNAME = get_env('REDIS_USERNAME') - self.REDIS_PASSWORD = get_env('REDIS_PASSWORD') - self.REDIS_DB = get_env('REDIS_DB') - self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') - - # storage settings - self.STORAGE_TYPE = get_env('STORAGE_TYPE') - self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') - self.S3_ENDPOINT = get_env('S3_ENDPOINT') - self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') - self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') - self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') - self.S3_REGION = get_env('S3_REGION') - - # vector store settings, only support weaviate, qdrant - self.VECTOR_STORE = get_env('VECTOR_STORE') - - # weaviate settings - self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') - self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY') - self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED') - self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE')) - - # qdrant settings - self.QDRANT_URL = get_env('QDRANT_URL') - self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') - - # milvus setting - self.MILVUS_HOST = get_env('MILVUS_HOST') - self.MILVUS_PORT = get_env('MILVUS_PORT') - self.MILVUS_USER = get_env('MILVUS_USER') - self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') - self.MILVUS_SECURE = get_env('MILVUS_SECURE') - - # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'WEB_API_CORS_ALLOW_ORIGINS', '*') - # mail settings - self.MAIL_TYPE = get_env('MAIL_TYPE') - self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM') - self.RESEND_API_KEY = get_env('RESEND_API_KEY') - - # sentry settings - self.SENTRY_DSN = get_env('SENTRY_DSN') - self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE')) - self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE')) - # check update url self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL') - # database settings + #
------------------------ + # Database Configurations. + # ------------------------ db_credentials = { key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE'] @@ -177,14 +159,102 @@ class Config: self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO') - # celery settings + # ------------------------ + # Redis Configurations. + # ------------------------ + self.REDIS_HOST = get_env('REDIS_HOST') + self.REDIS_PORT = get_env('REDIS_PORT') + self.REDIS_USERNAME = get_env('REDIS_USERNAME') + self.REDIS_PASSWORD = get_env('REDIS_PASSWORD') + self.REDIS_DB = get_env('REDIS_DB') + self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') + + # ------------------------ + # Celery worker Configurations. + # ------------------------ self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL') self.CELERY_BACKEND = get_env('CELERY_BACKEND') self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') - # hosted provider credentials + # ------------------------ + # File Storage Configurations. + # ------------------------ + self.STORAGE_TYPE = get_env('STORAGE_TYPE') + self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') + self.S3_ENDPOINT = get_env('S3_ENDPOINT') + self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') + self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') + self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') + self.S3_REGION = get_env('S3_REGION') + + # ------------------------ + # Vector Store Configurations. + # Currently supported: qdrant, milvus, zilliz, weaviate + # ------------------------ + self.VECTOR_STORE = get_env('VECTOR_STORE') + + # qdrant settings + self.QDRANT_URL = get_env('QDRANT_URL') + self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') + + # milvus / zilliz setting + self.MILVUS_HOST = get_env('MILVUS_HOST') + self.MILVUS_PORT = get_env('MILVUS_PORT') + self.MILVUS_USER = get_env('MILVUS_USER') + self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') + self.MILVUS_SECURE = get_env('MILVUS_SECURE') + + # weaviate settings + self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') + self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY') + self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED') + self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE')) + + # ------------------------ + # Mail Configurations. + # ------------------------ + self.MAIL_TYPE = get_env('MAIL_TYPE') + self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM') + self.RESEND_API_KEY = get_env('RESEND_API_KEY') + + # ------------------------ + # Sentry Configurations. + # ------------------------ + self.SENTRY_DSN = get_env('SENTRY_DSN') + self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE')) + self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE')) + + # ------------------------ + # Business Configurations. + # ------------------------ + + # multimodal send image format, supports base64 and url; default is base64 + self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT') + + # Dataset Configurations. + self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT') + self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING') + + # File upload Configurations.
+ self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) + self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) + self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT')) + + # Moderation in app Configurations. + self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE')) + + # Notion integration setting + self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') + self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') + self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE') + self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET') + self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN') + + # ------------------------ + # Platform Configurations. + # ------------------------ self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED') self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY') self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE') @@ -212,26 +282,6 @@ class Config: self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') - self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') - self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') - - # notion import setting - self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') - self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') - self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE') - self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET') - self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN') - - self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT') - self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING') - - # uploading settings - self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) - self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) - - # moderation settings - self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE')) - class CloudEditionConfig(Config): @@ -246,18 +296,5 @@ class CloudEditionConfig(Config): self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') - -class TestConfig(Config): - - def __init__(self): - super().__init__() - - self.EDITION = "SELF_HOSTED" - self.TESTING = True - - db_credentials = { - key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT'] - } - - # use a different database for testing: dify_test - self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test" + self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') + self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 1da7bd8f2c..61a4f1b69d 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -40,12 +40,14 @@ class CompletionMessageApi(Resource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') 
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') args = parser.parse_args() streaming = args['response_mode'] != 'blocking' + args['auto_generate_name'] = False account = flask_login.current_user @@ -113,6 +115,7 @@ class ChatMessageApi(Resource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') @@ -120,6 +123,7 @@ class ChatMessageApi(Resource): args = parser.parse_args() streaming = args['response_mode'] != 'blocking' + args['auto_generate_name'] = False account = flask_login.current_user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index af4323324e..66abf63e65 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource): conversation_id = str(conversation_id) return _get_conversation(app_id, conversation_id, 'completion') - + @setup_required @login_required @account_initialization_required @@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource): conversation_id = str(conversation_id) return _get_conversation(app_id, conversation_id, 'chat') - + @setup_required @login_required @account_initialization_required @@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource): return {'result': 'success'}, 204 - - api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations') api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>') api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations') diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 1631cff1c6..9422526556 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,7 +1,6 @@ import datetime import json -from cachetools import TTLCache from flask import request from flask_login import current_user from libs.login import login_required @@ -20,8 +19,6 @@ from models.source import DataSourceBinding from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task -cache = TTLCache(maxsize=None, ttl=30) - class DataSourceApi(Resource): diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index de2e347b18..c6b7391375 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,5 +1,5 @@ -from cachetools import TTLCache from flask import request, current_app +from flask_login import current_user import services from libs.login import login_required @@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields from services.file_service import FileService -cache = TTLCache(maxsize=None, ttl=30) - -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] PREVIEW_WORDS_LIMIT = 3000 @@ -30,9 +27,11 @@ class FileApi(Resource): def get(self): file_size_limit =
current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT") + image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") return { 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit + 'batch_count_limit': batch_count_limit, + 'image_file_size_limit': image_file_size_limit }, 200 @setup_required @@ -51,7 +50,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() try: - upload_file = FileService.upload_file(file) + upload_file = FileService.upload_file(file, current_user) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index bdf1f3b907..45ef582b6b 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,7 @@ # -*- coding:utf-8 -*- import json import logging +from datetime import datetime from typing import Generator, Union from flask import Response, stream_with_context @@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.conversation_message_task import PubHandler from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from extensions.ext_database import db from libs.helper import uuid_value from services.completion_service import CompletionService @@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False + + installed_app.last_used_at = datetime.utcnow() + db.session.commit() try: response = CompletionService.completion( @@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False + + installed_app.last_used_at = datetime.utcnow() + db.session.commit() try: response = CompletionService.completion( diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index aa56437f33..f6d4d84a90 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -38,7 +38,8 
@@ class ConversationListApi(InstalledAppResource): user=current_user, last_id=args['last_id'], limit=args['limit'], - pinned=pinned + pinned=pinned, + exclude_debug_conversation=True ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, current_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + current_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index d36d02828e..d7ee991663 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource): } for installed_app in installed_apps ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at'] - if app['last_used_at'] is not None else datetime.min)) + installed_apps.sort(key=lambda app: (-app['is_pinned'], + app['last_used_at'] is None, + -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) return {'installed_apps': installed_apps} diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 9514834a2b..63066f9f56 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,5 +1,6 @@ # -*- coding:utf-8 -*- from flask_restful import marshal_with, fields +from flask import current_app from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource @@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource): 'options': fields.List(fields.String) } + system_parameters_fields = { + 'image_file_size_limit': fields.String + } + parameters_fields = { 'opening_statement': fields.String, 'suggested_questions': fields.Raw, @@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw + 'sensitive_word_avoidance': fields.Raw, + 'file_upload': fields.Raw, + 'system_parameters': fields.Nested(system_parameters_fields) } @marshal_with(parameters_fields) @@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource): 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, + 'file_upload': app_model_config.file_upload_dict, + 'system_parameters': { + 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') + } } diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 3f9bc63096..7977d0ebaf 100644 --- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py @@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource from libs.helper import uuid_value, TimestampField from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService +from fields.conversation_fields import message_file_fields feedback_fields = { 'rating': fields.String @@ -19,6 +20,7 @@ message_fields = { 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'created_at': TimestampField } diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py index 61ba50325e..9282b1cba2 100644 --- a/api/controllers/console/universal_chat/chat.py +++ b/api/controllers/console/universal_chat/chat.py @@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource): parser = reqparse.RequestParser() parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('provider', type=str, required=True, location='json') parser.add_argument('model', type=str, required=True, location='json') @@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource): del args['model'] del args['tools'] + args['auto_generate_name'] = False + try: response = CompletionService.completion( app_model=app_model, diff --git a/api/controllers/console/universal_chat/conversation.py b/api/controllers/console/universal_chat/conversation.py index c0782cb81a..a85e392c25 100644 --- a/api/controllers/console/universal_chat/conversation.py +++ b/api/controllers/console/universal_chat/conversation.py @@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, current_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + current_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py new file mode 100644 index 0000000000..83374497b7 --- /dev/null +++ b/api/controllers/files/__init__.py @@ -0,0 +1,10 @@ +# -*- coding:utf-8 -*- +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint('files', __name__) +api = ExternalApi(bp) + + +from .
import image_preview diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py new file mode 100644 index 0000000000..2f989ddcde --- /dev/null +++ b/api/controllers/files/image_preview.py @@ -0,0 +1,40 @@ +from flask import request, Response +from flask_restful import Resource + +import services +from controllers.files import api +from libs.exception import BaseHTTPException +from services.file_service import FileService + + +class ImagePreviewApi(Resource): + def get(self, file_id): + file_id = str(file_id) + + timestamp = request.args.get('timestamp') + nonce = request.args.get('nonce') + sign = request.args.get('sign') + + if not timestamp or not nonce or not sign: + return {'content': 'Invalid request.'}, 400 + + try: + generator, mimetype = FileService.get_image_preview( + file_id, + timestamp, + nonce, + sign + ) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return Response(generator, mimetype=mimetype) + + +api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview') + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed." + code = 415 diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 0c3ec30072..b0ee669362 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1') api = ExternalApi(bp) -from .app import completion, app, conversation, message, audio +from .app import completion, app, conversation, message, audio, file from .dataset import document, segment, dataset diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index cca4c1addb..f38f60cf5b 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,5 +1,6 @@ # -*- coding:utf-8 -*- from flask_restful import fields, marshal_with +from flask import current_app from controllers.service_api import api from controllers.service_api.wraps import AppApiResource @@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource): 'options': fields.List(fields.String) } + system_parameters_fields = { + 'image_file_size_limit': fields.String + } + parameters_fields = { 'opening_statement': fields.String, 'suggested_questions': fields.Raw, @@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw + 'sensitive_word_avoidance': fields.Raw, + 'file_upload': fields.Raw, + 'system_parameters': fields.Nested(system_parameters_fields) } @marshal_with(parameters_fields) @@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource): 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, + 'file_upload': app_model_config.file_upload_dict, + 'system_parameters': { + 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') + } } diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5ab8a7d116..e72164022d 100644 --- a/api/controllers/service_api/app/completion.py +++
b/api/controllers/service_api/app/completion.py @@ -28,6 +28,7 @@ class CompletionApi(AppApiResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('user', type=str, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') @@ -39,13 +40,15 @@ if end_user is None and args['user'] is not None: end_user = create_or_update_end_user_for_user_id(app_model, args['user']) + args['auto_generate_name'] = False + try: response = CompletionService.completion( app_model=app_model, user=end_user, args=args, from_source='api', - streaming=streaming + streaming=streaming, ) return compact_response(response) @@ -90,10 +93,12 @@ class ChatApi(AppApiResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('user', type=str, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') args = parser.parse_args() diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index e111ea2ebc..2fdddef8a1 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') parser.add_argument('user', type=str, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() if end_user is None and args['user'] is not None: end_user = create_or_update_end_user_for_user_id(app_model, args['user']) try: - return ConversationService.rename(app_model, conversation_id, end_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + end_user, + args['name'], + args['auto_generate'] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index f509dc1b48..56beb56949 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException): description = "Provider not support speech to text." code = 400 + +class NoFileUploadedError(BaseHTTPException): + error_code = 'no_file_uploaded' + description = "Please upload your file."
+ code = 400 + + +class TooManyFilesError(BaseHTTPException): + error_code = 'too_many_files' + description = "Only one file is allowed." + code = 400 + + +class FileTooLargeError(BaseHTTPException): + error_code = 'file_too_large' + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed." + code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py new file mode 100644 index 0000000000..f928e083a5 --- /dev/null +++ b/api/controllers/service_api/app/file.py @@ -0,0 +1,42 @@ +from flask import request +from flask_restful import marshal_with + +from controllers.service_api import api +from controllers.service_api.wraps import AppApiResource +from controllers.service_api.app import create_or_update_end_user_for_user_id +from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ + UnsupportedFileTypeError +import services +from services.file_service import FileService +from fields.file_fields import file_fields + + +class FileApi(AppApiResource): + + @marshal_with(file_fields) + def post(self, app_model, end_user): + # check file before reading it, so a missing upload raises + # NoFileUploadedError instead of a bare KeyError + if 'file' not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + file = request.files['file'] + user_args = request.form.get('user') + + if end_user is None and user_args is not None: + end_user = create_or_update_end_user_for_user_id(app_model, user_args) + + try: + upload_file = FileService.upload_file(file, end_user) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 + + +api.add_resource(FileApi, '/files/upload') diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 614af37653..16106c340e 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value from services.message_service import MessageService from extensions.ext_database import db from models.model import Message, EndUser - +from fields.conversation_fields import message_file_fields class MessageListApi(AppApiResource): feedback_fields = { @@ -43,6 +43,7 @@ class MessageListApi(AppApiResource): 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index de786f8ccf..28545a36ab 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -2,6 +2,7 @@ import json from flask import request from flask_restful import reqparse, marshal +from flask_login import current_user from sqlalchemy import desc from werkzeug.exceptions import NotFound @@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file)
+ upload_file = FileService.upload_file(file, current_user) data_source = { 'type': 'upload_file', 'info_list': { @@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file) + upload_file = FileService.upload_file(file, current_user) data_source = { 'type': 'upload_file', 'info_list': { diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 0808dce5c4..3fba1869ce 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api') api = ExternalApi(bp) -from . import completion, app, conversation, message, site, saved_message, audio, passport +from . import completion, app, conversation, message, site, saved_message, audio, passport, file diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index bb99e26ad1..45213c4c75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,5 +1,6 @@ # -*- coding:utf-8 -*- from flask_restful import marshal_with, fields +from flask import current_app from controllers.web import api from controllers.web.wraps import WebApiResource @@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource): 'options': fields.List(fields.String) } + system_parameters_fields = { + 'image_file_size_limit': fields.String + } + parameters_fields = { 'opening_statement': fields.String, 'suggested_questions': fields.Raw, @@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw + 'sensitive_word_avoidance': fields.Raw, + 'file_upload': fields.Raw, + 'system_parameters': fields.Nested(system_parameters_fields) } @marshal_with(parameters_fields) @@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource): 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, + 'file_upload': app_model_config.file_upload_dict, + 'system_parameters': { + 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') + } } diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 579c1761c5..b49fec5110 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -30,12 +30,14 @@ class CompletionApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False try: response = CompletionService.completion( @@ -88,6 +90,7 @@ class ChatApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('files', 
type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') @@ -95,6 +98,7 @@ class ChatApi(WebApiResource): args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False try: response = CompletionService.completion( diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index ce089ca395..f6bb96bf18 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, end_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + end_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index c66bbd85b2..4566c323a2 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException): class ProviderNotSupportSpeechToTextError(BaseHTTPException): error_code = 'provider_not_support_speech_to_text' description = "Provider not support speech to text." - code = 400 \ No newline at end of file + code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = 'no_file_uploaded' + description = "Please upload your file." + code = 400 + + +class TooManyFilesError(BaseHTTPException): + error_code = 'too_many_files' + description = "Only one file is allowed." + code = 400 + + +class FileTooLargeError(BaseHTTPException): + error_code = 'file_too_large' + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed."
+ code = 415 diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py new file mode 100644 index 0000000000..985e9c5b58 --- /dev/null +++ b/api/controllers/web/file.py @@ -0,0 +1,36 @@ +from flask import request +from flask_restful import marshal_with + +from controllers.web import api +from controllers.web.wraps import WebApiResource +from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ + UnsupportedFileTypeError +import services +from services.file_service import FileService +from fields.file_fields import file_fields + + +class FileApi(WebApiResource): + + @marshal_with(file_fields) + def post(self, app_model, end_user): + # check file before reading it, so a missing upload raises + # NoFileUploadedError instead of a bare KeyError + if 'file' not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + # get file from request + file = request.files['file'] + try: + upload_file = FileService.upload_file(file, end_user) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 + + +api.add_resource(FileApi, '/files/upload') diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 2adc1db45f..f43b7f3007 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService +from fields.conversation_fields import message_file_fields class MessageListApi(WebApiResource): @@ -54,6 +55,7 @@ class MessageListApi(WebApiResource): 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 7f6f4249c9..888032cdee 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource from libs.helper import uuid_value, TimestampField from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService +from fields.conversation_fields import message_file_fields + feedback_fields = { 'rating': fields.String @@ -18,6 +20,7 @@ message_fields = { 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'created_at': TimestampField } diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 109dda68cd..8cd8c693d0 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -11,7 +11,8 @@ from pydantic import BaseModel from core.callback_handler.entity.llm_message import LLMMessage from core.conversation_message_task import ConversationMessageTask,
ConversationTaskStoppedException, \ ConversationTaskInterruptException -from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage +from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \ + ImagePromptMessageFile from core.model_providers.models.llm.base import BaseLLM from core.moderation.base import ModerationOutputsResult, ModerationAction from core.moderation.factory import ModerationFactory @@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler): real_prompts.append({ "role": role, - "text": message.content + "text": message.content, + "files": [{ + "type": file.type.value, + "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:], + "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None, + } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])] }) self.llm_message.prompt = real_prompts diff --git a/api/core/completion.py b/api/core/completion.py index 0f7c140263..b4f5e36b4c 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ ConversationTaskInterruptException from core.external_data_tool.factory import ExternalDataToolFactory +from core.file.file_obj import FileObj from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ ReadOnlyConversationTokenDBBufferSharedMemory from core.model_providers.model_factory import ModelFactory -from core.model_providers.models.entity.message import PromptMessage +from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile from core.model_providers.models.llm.base import BaseLLM from core.orchestrator_rule_parser import OrchestratorRuleParser from core.prompt.prompt_template import PromptTemplateParser @@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory class Completion: @classmethod def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, - user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, - is_override: bool = False, retriever_from: str = 'dev'): + files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation], + streaming: bool, is_override: bool = False, retriever_from: str = 'dev', + auto_generate_name: bool = True): """ errors: ProviderTokenNotInitError """ @@ -64,16 +66,21 @@ class Completion: is_override=is_override, inputs=inputs, query=query, + files=files, streaming=streaming, - model_instance=final_model_instance + model_instance=final_model_instance, + auto_generate_name=auto_generate_name ) + prompt_message_files = [file.prompt_message_file for file in files] + rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( mode=app.mode, model_instance=final_model_instance, app_model_config=app_model_config, query=query, - inputs=inputs + inputs=inputs, + files=prompt_message_files ) # init orchestrator rule parser @@ -95,6 +102,7 @@ class Completion: app_model_config=app_model_config, query=query, inputs=inputs, + files=prompt_message_files, agent_execute_result=None, conversation_message_task=conversation_message_task, memory=memory, @@ -146,6 +154,7 @@ class Completion: app_model_config=app_model_config, query=query, inputs=inputs, 
+ files=prompt_message_files, agent_execute_result=agent_execute_result, conversation_message_task=conversation_message_task, memory=memory, @@ -257,6 +266,7 @@ class Completion: @classmethod def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, + files: List[PromptMessageFile], agent_execute_result: Optional[AgentExecuteResult], conversation_message_task: ConversationMessageTask, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], @@ -266,10 +276,12 @@ class Completion: # get llm prompt if app_model_config.prompt_type == 'simple': prompt_messages, stop_words = prompt_transform.get_prompt( - mode=mode, + app_mode=mode, + app_model_config=app_model_config, pre_prompt=app_model_config.pre_prompt, inputs=inputs, query=query, + files=files, context=agent_execute_result.output if agent_execute_result else None, memory=memory, model_instance=model_instance @@ -280,6 +292,7 @@ class Completion: app_model_config=app_model_config, inputs=inputs, query=query, + files=files, context=agent_execute_result.output if agent_execute_result else None, memory=memory, model_instance=model_instance @@ -337,7 +350,7 @@ class Completion: @classmethod def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig, - query: str, inputs: dict) -> int: + query: str, inputs: dict, files: List[PromptMessageFile]) -> int: model_limited_tokens = model_instance.model_rules.max_tokens.max max_tokens = model_instance.get_model_kwargs().max_tokens @@ -348,15 +361,16 @@ class Completion: max_tokens = 0 prompt_transform = PromptTransform() - prompt_messages = [] # get prompt without memory and context if app_model_config.prompt_type == 'simple': prompt_messages, _ = prompt_transform.get_prompt( - mode=mode, + app_mode=mode, + app_model_config=app_model_config, pre_prompt=app_model_config.pre_prompt, inputs=inputs, query=query, + files=files, context=None, memory=None, model_instance=model_instance @@ -367,6 +381,7 @@ class Completion: app_model_config=app_model_config, inputs=inputs, query=query, + files=files, context=None, memory=None, model_instance=model_instance diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 9dd211d360..a1a3affe51 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.callback_handler.entity.llm_message import LLMMessage from core.callback_handler.entity.chain_result import ChainResult +from core.file.file_obj import FileObj from core.model_providers.model_factory import ModelFactory -from core.model_providers.models.entity.message import to_prompt_messages, MessageType +from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile from core.model_providers.models.llm.base import BaseLLM from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser @@ -16,13 +17,14 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DatasetQuery from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \ - MessageChain, DatasetRetrieverResource + MessageChain, DatasetRetrieverResource, MessageFile class ConversationMessageTask: def __init__(self, task_id: 
str, app: App, app_model_config: AppModelConfig, user: Account, - inputs: dict, query: str, streaming: bool, model_instance: BaseLLM, - conversation: Optional[Conversation] = None, is_override: bool = False): + inputs: dict, query: str, files: List[FileObj], streaming: bool, + model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False, + auto_generate_name: bool = True): self.start_at = time.perf_counter() self.task_id = task_id @@ -35,6 +37,7 @@ class ConversationMessageTask: self.user = user self.inputs = inputs self.query = query + self.files = files self.streaming = streaming self.conversation = conversation @@ -45,6 +48,7 @@ class ConversationMessageTask: self.message = None self.retriever_resource = None + self.auto_generate_name = auto_generate_name self.model_dict = self.app_model_config.model_dict self.provider_name = self.model_dict.get('provider') @@ -100,7 +104,7 @@ class ConversationMessageTask: model_id=self.model_name, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=self.mode, - name='', + name='New conversation', inputs=self.inputs, introduction=introduction, system_instruction=system_instruction, @@ -142,6 +146,19 @@ class ConversationMessageTask: db.session.add(self.message) db.session.commit() + for file in self.files: + message_file = MessageFile( + message_id=self.message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), + created_by=self.user.id + ) + db.session.add(message_file) + db.session.commit() + def append_message_text(self, text: str): if text is not None: self._pub_handler.pub_text(text) @@ -176,7 +193,8 @@ class ConversationMessageTask: message_was_created.send( self.message, conversation=self.conversation, - is_first_message=self.is_new_conversation + is_first_message=self.is_new_conversation, + auto_generate_name=self.auto_generate_name ) if not by_stopped: diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py new file mode 100644 index 0000000000..e2487e7ed1 --- /dev/null +++ b/api/core/file/file_obj.py @@ -0,0 +1,79 @@ +import enum +from typing import Optional + +from pydantic import BaseModel + +from core.file.upload_file_parser import UploadFileParser +from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile +from extensions.ext_database import db +from models.model import UploadFile + + +class FileType(enum.Enum): + IMAGE = 'image' + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(enum.Enum): + REMOTE_URL = 'remote_url' + LOCAL_FILE = 'local_file' + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileObj(BaseModel): + id: Optional[str] + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + url: Optional[str] + upload_file_id: Optional[str] + file_config: dict + + @property + def data(self) -> Optional[str]: + return self._get_data() + + @property + def preview_url(self) -> Optional[str]: + return 
self._get_data(force_url=True) + + @property + def prompt_message_file(self) -> PromptMessageFile: + if self.type == FileType.IMAGE: + image_config = self.file_config.get('image') + + return ImagePromptMessageFile( + data=self.data, + detail=ImagePromptMessageFile.DETAIL.HIGH + if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW + ) + + def _get_data(self, force_url: bool = False) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + upload_file = (db.session.query(UploadFile) + .filter( + UploadFile.id == self.upload_file_id, + UploadFile.tenant_id == self.tenant_id + ).first()) + + return UploadFileParser.get_image_data( + upload_file=upload_file, + force_url=force_url + ) + + return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py new file mode 100644 index 0000000000..faa1c8badb --- /dev/null +++ b/api/core/file/message_file_parser.py @@ -0,0 +1,180 @@ +from typing import List, Union, Optional, Dict + +import requests + +from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.file.upload_file_parser import SUPPORT_EXTENSIONS +from extensions.ext_database import db +from models.account import Account +from models.model import MessageFile, EndUser, AppModelConfig, UploadFile + + +class MessageFileParser: + + def __init__(self, tenant_id: str, app_id: str) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + + def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig, + user: Union[Account, EndUser]) -> List[FileObj]: + """ + validate and transform files arg + + :param files: + :param app_model_config: + :param user: + :return: + """ + file_upload_config = app_model_config.file_upload_dict + + for file in files: + if not isinstance(file, dict): + raise ValueError('Invalid file format, must be dict') + if not file.get('type'): + raise ValueError('Missing file type') + FileType.value_of(file.get('type')) + if not file.get('transfer_method'): + raise ValueError('Missing file transfer method') + FileTransferMethod.value_of(file.get('transfer_method')) + if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: + if not file.get('url'): + raise ValueError('Missing file url') + if not file.get('url').startswith('http'): + raise ValueError('Invalid file url') + if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): + raise ValueError('Missing file upload_file_id') + + # transform files to file objs + type_file_objs = self._to_file_objs(files, file_upload_config) + + # validate files + new_files = [] + for file_type, file_objs in type_file_objs.items(): + if file_type == FileType.IMAGE: + # parse and validate files + image_config = file_upload_config.get('image') + + # check if image file feature is enabled + if not image_config['enabled']: + continue + + # Validate number of files + if len(files) > image_config['number_limits']: + raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") + + for file_obj in file_objs: + # Validate transfer method + if file_obj.transfer_method.value not in image_config['transfer_methods']: + raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') + + # Validate file type + if file_obj.type != FileType.IMAGE: + raise ValueError(f'Invalid file type: 
{file_obj.type}') + + if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: + # check that the remote URL is valid and points to an image + result, error = self._check_image_remote_url(file_obj.url) + if result is False: + raise ValueError(error) + elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: + # get upload file from upload_file_id + upload_file = (db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.upload_file_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), + UploadFile.extension.in_(SUPPORT_EXTENSIONS) + ).first()) + + # check that the upload file belongs to the tenant and user + if not upload_file: + raise ValueError('Invalid upload file') + + new_files.append(file_obj) + + # return all file objs + return new_files + + def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]: + """ + transform message files + + :param files: + :param app_model_config: + :return: + """ + # transform files to file objs + type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict) + + # return all file objs + return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] + + def _to_file_objs(self, files: List[Union[Dict, MessageFile]], + file_upload_config: dict) -> Dict[FileType, List[FileObj]]: + """ + transform files to file objs + + :param files: + :param file_upload_config: + :return: + """ + type_file_objs: Dict[FileType, List[FileObj]] = { + # Currently only images are supported + FileType.IMAGE: [] + } + + if not files: + return type_file_objs + + # group by file type and convert file args or message files to FileObj + for file in files: + file_obj = self._to_file_obj(file, file_upload_config) + if file_obj.type not in type_file_objs: + continue + + type_file_objs[file_obj.type].append(file_obj) + + return type_file_objs + + def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj: + """ + transform file to file obj + + :param file: + :return: + """ + if isinstance(file, dict): + transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + return FileObj( + tenant_id=self.tenant_id, + type=FileType.value_of(file.get('type')), + transfer_method=transfer_method, + url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + file_config=file_upload_config + ) + else: + return FileObj( + id=file.id, + tenant_id=self.tenant_id, + type=FileType.value_of(file.type), + transfer_method=FileTransferMethod.value_of(file.transfer_method), + url=file.url, + upload_file_id=file.upload_file_id or None, + file_config=file_upload_config + ) + + def _check_image_remote_url(self, url): + try: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + + response = requests.head(url, headers=headers, allow_redirects=True) + if response.status_code == 200: + return True, "" + else: + return False, "URL does not exist."
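+ # note: a HEAD request checks reachability without downloading the image body; servers that reject HEAD will fail this check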
+ except requests.RequestException as e: + return False, f"Error checking URL: {e}" diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py new file mode 100644 index 0000000000..92335c297f --- /dev/null +++ b/api/core/file/upload_file_parser.py @@ -0,0 +1,79 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from typing import Optional + +from flask import current_app + +from extensions.ext_storage import storage + +SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif'] + + +class UploadFileParser: + @classmethod + def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: + if not upload_file: + return None + + if upload_file.extension not in SUPPORT_EXTENSIONS: + return None + + if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url: + return cls.get_signed_temp_image_url(upload_file) + else: + # encode the image file as a base64 data URI + try: + data = storage.load(upload_file.key) + except FileNotFoundError: + logging.error(f'File not found: {upload_file.key}') + return None + + encoded_string = base64.b64encode(data).decode('utf-8') + return f'data:{upload_file.mime_type};base64,{encoded_string}' + + @classmethod + def get_signed_temp_image_url(cls, upload_file) -> str: + """ + get a signed, time-limited image preview url for an upload file + + :param upload_file: UploadFile object + :return: + """ + base_url = current_app.config.get('FILES_URL') + image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview' + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}" + secret_key = current_app.config['SECRET_KEY'].encode() + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @classmethod + def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + + :param upload_file_id: file id + :param timestamp: timestamp + :param nonce: nonce + :param sign: signature + :return: + """ + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = current_app.config['SECRET_KEY'].encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= 300 # the signed url expires after 5 minutes diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index a6699f32d7..87a934e55d 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: @classmethod - def generate_conversation_name(cls, tenant_id: str, query, answer): + def generate_conversation_name(cls, tenant_id: str, query): prompt = CONVERSATION_TITLE_PROMPT if len(query) > 2000: @@ -40,8 +40,12 @@ class LLMGenerator: result_dict = json.loads(answer) answer = result_dict['Your Output'] + name = answer.strip() - return answer.strip() + if len(name) > 75: + name = name[:75] + '...'
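+ # cap auto-generated conversation names at 75 characters so titles stay display-friendly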
+ + return name @classmethod def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py index 755df1201a..4123521ae3 100644 --- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py +++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py @@ -3,6 +3,7 @@ from typing import Any, List, Dict from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import get_buffer_string, BaseMessage +from core.file.message_file_parser import MessageFileParser from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages from core.model_providers.models.llm.base import BaseLLM from extensions.ext_database import db @@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): @property def buffer(self) -> List[BaseMessage]: """String buffer of memory.""" + app_model = self.conversation.app + # fetch limited messages desc, and return reversed messages = db.session.query(Message).filter( Message.conversation_id == self.conversation.id, @@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): ).order_by(Message.created_at.desc()).limit(self.message_limit).all() messages = list(reversed(messages)) + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id) chat_messages: List[PromptMessage] = [] for message in messages: - chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER)) + files = message.message_files + if files: + file_objs = message_file_parser.transform_message_files( + files, message.app_model_config + ) + + prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs] + chat_messages.append(PromptMessage( + content=message.query, + type=MessageType.USER, + files=prompt_message_files + )) + else: + chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER)) + chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) if not chat_messages: diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py index 1ae04d67f5..e3e49ba0f4 100644 --- a/api/core/model_providers/models/entity/message.py +++ b/api/core/model_providers/models/entity/message.py @@ -1,4 +1,5 @@ import enum +from typing import Any, cast, Union, List, Dict from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage from pydantic import BaseModel @@ -18,17 +19,53 @@ class MessageType(enum.Enum): SYSTEM = 'system' +class PromptMessageFileType(enum.Enum): + IMAGE = 'image' + + @staticmethod + def value_of(value): + for member in PromptMessageFileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + + +class PromptMessageFile(BaseModel): + type: PromptMessageFileType + data: Any + + +class ImagePromptMessageFile(PromptMessageFile): + class DETAIL(enum.Enum): + LOW = 'low' + HIGH = 'high' + + type: PromptMessageFileType = PromptMessageFileType.IMAGE + detail: DETAIL = DETAIL.LOW + + class PromptMessage(BaseModel): type: MessageType = MessageType.USER content: str = '' + files: list[PromptMessageFile] = [] function_call: dict = None +class LCHumanMessageWithFiles(HumanMessage): + # content: Union[str, List[Union[str, 
Dict]]] + content: str + files: list[PromptMessageFile] + + def to_lc_messages(messages: list[PromptMessage]): lc_messages = [] for message in messages: if message.type == MessageType.USER: - lc_messages.append(HumanMessage(content=message.content)) + if not message.files: + lc_messages.append(HumanMessage(content=message.content)) + else: + lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files)) elif message.type == MessageType.ASSISTANT: additional_kwargs = {} if message.function_call: @@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]): prompt_messages = [] for message in messages: if isinstance(message, HumanMessage): - prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) + if isinstance(message, LCHumanMessageWithFiles): + prompt_messages.append(PromptMessage( + content=message.content, + type=MessageType.USER, + files=message.files + )) + else: + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) elif isinstance(message, AIMessage): message_kwargs = { 'content': message.content, diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py index 04759e63bc..08331885b8 100644 --- a/api/core/model_providers/models/llm/openai_model.py +++ b/api/core/model_providers/models/llm/openai_model.py @@ -1,11 +1,9 @@ -import decimal import logging from typing import List, Optional, Any import openai from langchain.callbacks.manager import Callbacks from langchain.schema import LLMResult -from openai import api_requestor from core.model_providers.providers.base import BaseModelProvider from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 62fd814678..cdfe08e9b0 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import BaseMessage from core.model_providers.models.entity.model_params import ModelMode -from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages +from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.baichuan_model import BaichuanModel from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel @@ -16,32 +16,59 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.models.llm.xinference_model import XinferenceModel from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser +from models.model import AppModelConfig + class AppMode(enum.Enum): COMPLETION = 'completion' CHAT = 'chat' + class PromptTransform: - def get_prompt(self, mode: str, - pre_prompt: str, inputs: dict, + def get_prompt(self, + app_mode: str, + app_model_config: AppModelConfig, + pre_prompt: str, + inputs: dict, query: str, + files: List[PromptMessageFile], context: Optional[str], memory: Optional[BaseChatMemory], model_instance: BaseLLM) -> \ Tuple[List[PromptMessage], Optional[List[str]]]: - prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance)) - prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, 
context, memory, model_instance) - return [PromptMessage(content=prompt)], stops + model_mode = app_model_config.model_dict['mode'] + + app_mode_enum = AppMode(app_mode) + model_mode_enum = ModelMode(model_mode) + + prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance)) + + if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT: + stops = None + + prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs, + query, context, memory, + model_instance, files) + else: + stops = prompt_rules.get('stops') + if stops is not None and len(stops) == 0: + stops = None + + prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context, + memory, + model_instance, files) + return prompt_messages, stops + + def get_advanced_prompt(self, + app_mode: str, + app_model_config: AppModelConfig, + inputs: dict, + query: str, + files: List[PromptMessageFile], + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: - def get_advanced_prompt(self, - app_mode: str, - app_model_config: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory], - model_instance: BaseLLM) -> List[PromptMessage]: - model_mode = app_model_config.model_dict['mode'] app_mode_enum = AppMode(app_mode) @@ -51,15 +78,20 @@ class PromptTransform: if app_mode_enum == AppMode.CHAT: if model_mode_enum == ModelMode.COMPLETION: - prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance) + prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, + files, context, memory, + model_instance) elif model_mode_enum == ModelMode.CHAT: - prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance) + prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files, + context, memory, model_instance) elif app_mode_enum == AppMode.COMPLETION: if model_mode_enum == ModelMode.CHAT: - prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context) + prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, + files, context) elif model_mode_enum == ModelMode.COMPLETION: - prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context) - + prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, + files, context) + return prompt_messages def _get_history_messages_from_memory(self, memory: BaseChatMemory, @@ -71,7 +103,7 @@ class PromptTransform: return external_context[memory_key] def _get_history_messages_list_from_memory(self, memory: BaseChatMemory, - max_token_limit: int) -> List[PromptMessage]: + max_token_limit: int) -> List[PromptMessage]: """Get memory messages.""" memory.max_token_limit = max_token_limit memory.return_messages = True @@ -79,7 +111,7 @@ class PromptTransform: external_context = memory.load_memory_variables({}) memory.return_messages = False return to_prompt_messages(external_context[memory_key]) - + def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str: # baichuan if isinstance(model_instance, BaichuanModel): @@ -94,13 +126,13 @@ class PromptTransform: return 'common_completion' else: 
return 'common_chat' - + def _prompt_file_name_for_baichuan(self, mode: str) -> str: if mode == 'completion': return 'baichuan_completion' else: return 'baichuan_chat' - + def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: # Get the absolute path of the subdirectory prompt_path = os.path.join( @@ -111,12 +143,53 @@ class PromptTransform: # Open the JSON file and read its content with open(json_file_path, 'r') as json_file: return json.load(json_file) - - def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory], - model_instance: BaseLLM) -> Tuple[str, Optional[list]]: + + def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM, + files: List[PromptMessageFile]) -> List[PromptMessage]: + prompt_messages = [] + + context_prompt_content = '' + if context and 'context_prompt' in prompt_rules: + prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) + context_prompt_content = prompt_template.format( + {'context': context} + ) + + pre_prompt_content = '' + if pre_prompt: + prompt_template = PromptTemplateParser(template=pre_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + pre_prompt_content = prompt_template.format( + prompt_inputs + ) + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt': + prompt += context_prompt_content + elif order == 'pre_prompt': + prompt += pre_prompt_content + + prompt = re.sub(r'<\|.*?\|>', '', prompt) + + prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt)) + + self._append_chat_histories(memory, prompt_messages, model_instance) + + prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files)) + + return prompt_messages + + def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM, + files: List[PromptMessageFile]) -> List[PromptMessage]: context_prompt_content = '' if context and 'context_prompt' in prompt_rules: prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) @@ -175,16 +248,12 @@ class PromptTransform: prompt = re.sub(r'<\|.*?\|>', '', prompt) - stops = prompt_rules.get('stops') - if stops is not None and len(stops) == 0: - stops = None + return [PromptMessage(content=prompt, files=files)] - return prompt, stops - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: if '#context#' in prompt_template.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs['#context#'] = context else: prompt_inputs['#context#'] = '' @@ -195,17 +264,18 @@ class PromptTransform: else: prompt_inputs['#query#'] = '' - def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, - prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None: + def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, + prompt_template: PromptTemplateParser, prompt_inputs: dict, + model_instance: BaseLLM) -> None: if '#histories#' in prompt_template.variable_keys: if memory: 
tmp_human_message = PromptBuilder.to_human_message( prompt_content=raw_prompt, - inputs={ '#histories#': '', **prompt_inputs } + inputs={'#histories#': '', **prompt_inputs} ) rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance) - + memory.human_prefix = conversation_histories_role['user_prefix'] memory.ai_prefix = conversation_histories_role['assistant_prefix'] histories = self._get_history_messages_from_memory(memory, rest_tokens) @@ -213,7 +283,8 @@ class PromptTransform: else: prompt_inputs['#histories#'] = '' - def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None: + def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], + model_instance: BaseLLM) -> None: if memory: rest_tokens = self._calculate_rest_token(prompt_messages, model_instance) @@ -242,19 +313,19 @@ class PromptTransform: return prompt def _get_chat_app_completion_model_prompt_messages(self, - app_model_config: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory], - model_instance: BaseLLM) -> List[PromptMessage]: - + app_model_config: AppModelConfig, + inputs: dict, + query: str, + files: List[PromptMessageFile], + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: + raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text'] conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role'] prompt_messages = [] - prompt = '' - + prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -262,28 +333,29 @@ class PromptTransform: self._set_query_variable(query, prompt_template, prompt_inputs) - self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance) + self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, + model_instance) prompt = self._format_prompt(prompt_template, prompt_inputs) - prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt)) + prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files)) return prompt_messages def _get_chat_app_chat_model_prompt_messages(self, - app_model_config: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory], - model_instance: BaseLLM) -> List[PromptMessage]: + app_model_config: AppModelConfig, + inputs: dict, + query: str, + files: List[PromptMessageFile], + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] prompt_messages = [] for prompt_item in raw_prompt_list: raw_prompt = prompt_item['text'] - prompt = '' prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -292,23 +364,23 @@ class PromptTransform: prompt = self._format_prompt(prompt_template, prompt_inputs) - prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) - + prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt)) + self._append_chat_histories(memory, prompt_messages, model_instance) - 
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query)) + prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files)) return prompt_messages def _get_completion_app_completion_model_prompt_messages(self, - app_model_config: str, - inputs: dict, - context: Optional[str]) -> List[PromptMessage]: + app_model_config: AppModelConfig, + inputs: dict, + files: List[PromptMessageFile], + context: Optional[str]) -> List[PromptMessage]: raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text'] prompt_messages = [] - prompt = '' - + prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -316,21 +388,21 @@ class PromptTransform: prompt = self._format_prompt(prompt_template, prompt_inputs) - prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt)) + prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files)) return prompt_messages def _get_completion_app_chat_model_prompt_messages(self, - app_model_config: str, - inputs: dict, - context: Optional[str]) -> List[PromptMessage]: + app_model_config: AppModelConfig, + inputs: dict, + files: List[PromptMessageFile], + context: Optional[str]) -> List[PromptMessage]: raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] prompt_messages = [] for prompt_item in raw_prompt_list: raw_prompt = prompt_item['text'] - prompt = '' prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -339,6 +411,11 @@ class PromptTransform: prompt = self._format_prompt(prompt_template, prompt_inputs) - prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) - - return prompt_messages \ No newline at end of file + prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt)) + + for prompt_message in prompt_messages[::-1]: + if prompt_message.type == MessageType.USER: + prompt_message.files = files + break + + return prompt_messages diff --git a/api/core/third_party/langchain/llms/chat_open_ai.py b/api/core/third_party/langchain/llms/chat_open_ai.py index 93c31526d1..33131ffc38 100644 --- a/api/core/third_party/langchain/llms/chat_open_ai.py +++ b/api/core/third_party/langchain/llms/chat_open_ai.py @@ -1,10 +1,13 @@ import os -from typing import Dict, Any, Optional, Union, Tuple +from typing import Dict, Any, Optional, Union, Tuple, List, cast from langchain.chat_models import ChatOpenAI +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage from pydantic import root_validator +from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile + class EnhanceChatOpenAI(ChatOpenAI): request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) @@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI): "api_key": self.openai_api_key, "organization": self.openai_organization if self.openai_organization else None, } + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = self._client_params + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + message_dicts = 
[self._convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. + + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + model, encoding = self._get_encoding_model() + if model.startswith("gpt-3.5-turbo-0301"): + # every message follows <im_start>{role/name}\n{content}<im_end>\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): + tokens_per_message = 3 + tokens_per_name = 1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [self._convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: Token calculation for the image type is not implemented yet; + # it would require downloading the image to get its resolution, + # which would add extra request latency + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item['text'] + + value = text + num_tokens += len(encoding.encode(str(value))) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with <im_start>assistant + num_tokens += 3 + return num_tokens + + def _convert_message_to_dict(self, message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, LCHumanMessageWithFiles): + content = [ + { + "type": "text", + "text": message.content + } + ] + + for file in message.files: + if file.type == PromptMessageFileType.IMAGE: + file = cast(ImagePromptMessageFile, file) + content.append({ + "type": "image_url", + "image_url": { + "url": file.data, + "detail": file.detail.value + } + }) + + message_dict = {"role": "user", "content": content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 4176f9dbbb..b35e67969b 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++
b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,5 +1,3 @@ -import logging - from core.generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db @@ -10,8 +8,9 @@ def handle(sender, **kwargs): message = sender conversation = kwargs.get('conversation') is_first_message = kwargs.get('is_first_message') + auto_generate_name = kwargs.get('auto_generate_name', True) - if is_first_message: + if auto_generate_name and is_first_message: if conversation.mode == 'chat': app_model = conversation.app if not app_model: @@ -19,14 +18,9 @@ def handle(sender, **kwargs): # generate conversation name try: - name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer) - - if len(name) > 75: - name = name[:75] + '...' - + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) conversation.name = name except: - conversation.name = 'New conversation' + pass - db.session.add(conversation) db.session.commit() diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index dc44892024..f591e8173e 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,6 +1,7 @@ import os import shutil from contextlib import closing +from typing import Union, Generator import boto3 from botocore.exceptions import ClientError @@ -45,7 +46,13 @@ class Storage: with open(os.path.join(os.getcwd(), filename), "wb") as f: f.write(data) - def load(self, filename): + def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]: + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) + + def load_once(self, filename: str) -> bytes: if self.storage_type == 's3': try: with closing(self.client) as client: @@ -69,6 +76,34 @@ class Storage: return data + def load_stream(self, filename: str) -> Generator: + def generate(filename: str = filename) -> Generator: + if self.storage_type == 's3': + try: + with closing(self.client) as client: + response = client.get_object(Bucket=self.bucket_name, Key=filename) + for chunk in response['Body'].iter_chunks(): + yield chunk + except ClientError as ex: + if ex.response['Error']['Code'] == 'NoSuchKey': + raise FileNotFoundError("File not found") + else: + raise + else: + if not self.folder or self.folder.endswith('/'): + filename = self.folder + filename + else: + filename = self.folder + '/' + filename + + if not os.path.exists(filename): + raise FileNotFoundError("File not found") + + with open(filename, "rb") as f: + while chunk := f.read(4096): # Read in chunks of 4KB + yield chunk + + return generate() + def download(self, filename, target_filepath): if self.storage_type == 's3': with closing(self.client) as client: diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 154b10ceb6..2c8b5eb109 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -32,7 +32,8 @@ model_config_fields = { 'prompt_type': fields.String, 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), - 'dataset_configs': fields.Raw(attribute='dataset_configs_dict') + 'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), + 'file_upload': fields.Raw(attribute='file_upload_dict'), } app_detail_fields = { @@ -140,4 +141,4 @@ app_site_fields = { 'privacy_policy': fields.String, 'customize_token_strategy': fields.String, 
'prompt_public': fields.Boolean -} \ No newline at end of file +} diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index df43a62fb6..49a96c2751 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -28,6 +28,12 @@ annotation_fields = { 'created_at': TimestampField } +message_file_fields = { + 'id': fields.String, + 'type': fields.String, + 'url': fields.String, +} + message_detail_fields = { 'id': fields.String, 'conversation_id': fields.String, @@ -43,7 +49,8 @@ message_detail_fields = { 'from_account_id': fields.String, 'feedbacks': fields.List(fields.Nested(feedback_fields)), 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'created_at': TimestampField + 'created_at': TimestampField, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), } feedback_stat_fields = { @@ -111,11 +118,6 @@ conversation_message_detail_fields = { 'message': fields.Nested(message_detail_fields, attribute='first_message'), } -simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, -} - conversation_with_summary_fields = { 'id': fields.String, 'status': fields.String, @@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = { 'limit': fields.Integer, 'has_more': fields.Boolean, 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) -} \ No newline at end of file +} diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index dcab11a7ad..2ef379dabc 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -4,7 +4,8 @@ from libs.helper import TimestampField upload_config_fields = { 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer + 'batch_count_limit': fields.Integer, + 'image_file_size_limit': fields.Integer, } file_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c2f7193f65..df0a1104fe 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,6 +1,7 @@ from flask_restful import fields from libs.helper import TimestampField +from fields.conversation_fields import message_file_fields feedback_fields = { 'rating': fields.String @@ -31,6 +32,7 @@ message_fields = { 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py new file mode 100644 index 0000000000..7aed3c5e6c --- /dev/null +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -0,0 +1,59 @@ +"""add gpt4v supports + +Revision ID: 8fe468ba0ca5 +Revises: a9836e3baeee +Create Date: 2023-11-09 11:39:00.006432 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8fe468ba0ca5' +down_revision = 'a9836e3baeee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('message_files', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('transfer_method', sa.String(length=255), nullable=False), + sa.Column('url', sa.Text(), nullable=True), + sa.Column('upload_file_id', postgresql.UUID(), nullable=True), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_file_pkey') + ) + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False) + batch_op.create_index('message_file_message_idx', ['message_id'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True)) + + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('created_by_role') + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('file_upload') + + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.drop_index('message_file_message_idx') + batch_op.drop_index('message_file_created_by_idx') + + op.drop_table('message_files') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 2a963cd8b2..b7cd428839 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,10 +1,10 @@ import json -from json import JSONDecodeError from flask import current_app, request from flask_login import UserMixin from sqlalchemy.dialects.postgresql import UUID +from core.file.upload_file_parser import UploadFileParser from libs.helper import generate_string from extensions.ext_database import db from .account import Account, Tenant @@ -98,6 +98,7 @@ class AppModelConfig(db.Model): completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) + file_upload = db.Column(db.Text) @property def app(self): @@ -161,6 +162,10 @@ class AppModelConfig(db.Model): def dataset_configs_dict(self) -> dict: return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} + @property + def file_upload_dict(self) -> dict: + return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} + def to_dict(self) -> dict: return { "provider": "", @@ -182,7 +187,8 @@ class AppModelConfig(db.Model): "prompt_type": self.prompt_type, "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, - "dataset_configs": self.dataset_configs_dict + "dataset_configs": self.dataset_configs_dict, + "file_upload": self.file_upload_dict } def from_model_config_dict(self, 
model_config: dict): @@ -213,6 +219,8 @@ class AppModelConfig(db.Model): if model_config.get('completion_prompt_config') else None self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ if model_config.get('dataset_configs') else None + self.file_upload = json.dumps(model_config.get('file_upload')) \ + if model_config.get('file_upload') else None return self def copy(self): @@ -238,7 +246,8 @@ class AppModelConfig(db.Model): prompt_type=self.prompt_type, chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, - dataset_configs=self.dataset_configs + dataset_configs=self.dataset_configs, + file_upload=self.file_upload ) return new_app_model_config @@ -512,6 +521,37 @@ class Message(db.Model): return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ .order_by(DatasetRetrieverResource.position.asc()).all() + @property + def message_files(self): + return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + + @property + def files(self): + message_files = self.message_files + + files = [] + for message_file in message_files: + url = message_file.url + if message_file.type == 'image': + if message_file.transfer_method == 'local_file': + upload_file = (db.session.query(UploadFile) + .filter( + UploadFile.id == message_file.upload_file_id + ).first()) + + url = UploadFileParser.get_image_data( + upload_file=upload_file, + force_url=True + ) + + files.append({ + 'id': message_file.id, + 'type': message_file.type, + 'url': url + }) + + return files + class MessageFeedback(db.Model): __tablename__ = 'message_feedbacks' @@ -540,6 +580,25 @@ class MessageFeedback(db.Model): return account +class MessageFile(db.Model): + __tablename__ = 'message_files' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_file_pkey'), + db.Index('message_file_message_idx', 'message_id'), + db.Index('message_file_created_by_idx', 'created_by') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + transfer_method = db.Column(db.String(255), nullable=False) + url = db.Column(db.Text, nullable=True) + upload_file_id = db.Column(UUID, nullable=True) + created_by_role = db.Column(db.String(255), nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + class MessageAnnotation(db.Model): __tablename__ = 'message_annotations' __table_args__ = ( @@ -683,6 +742,7 @@ class UploadFile(db.Model): size = db.Column(db.Integer, nullable=False) extension = db.Column(db.String(255), nullable=False) mime_type = db.Column(db.String(255), nullable=True) + created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -783,4 +843,3 @@ class DatasetRetrieverResource(db.Model): retriever_from = db.Column(db.Text, nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) - diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 00260708b9..be7947d7f6 
100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -315,6 +315,9 @@ class AppModelConfigService: # moderation validation cls.is_moderation_valid(tenant_id, config) + # file upload validation + cls.is_file_upload_valid(config) + # Filter out extra parameters filtered_config = { "opening_statement": config["opening_statement"], @@ -338,7 +341,8 @@ class AppModelConfigService: "prompt_type": config["prompt_type"], "chat_prompt_config": config["chat_prompt_config"], "completion_prompt_config": config["completion_prompt_config"], - "dataset_configs": config["dataset_configs"] + "dataset_configs": config["dataset_configs"], + "file_upload": config["file_upload"] } return filtered_config @@ -371,6 +375,34 @@ class AppModelConfigService: config=config ) + @classmethod + def is_file_upload_valid(cls, config: dict): + if 'file_upload' not in config or not config["file_upload"]: + config["file_upload"] = {} + + if not isinstance(config["file_upload"], dict): + raise ValueError("file_upload must be of dict type") + + # check image config + if 'image' not in config["file_upload"] or not config["file_upload"]["image"]: + config["file_upload"]["image"] = {"enabled": False} + + if config['file_upload']['image']['enabled']: + number_limits = config['file_upload']['image']['number_limits'] + if number_limits < 1 or number_limits > 6: + raise ValueError("number_limits must be in [1, 6]") + + detail = config['file_upload']['image']['detail'] + if detail not in ['high', 'low']: + raise ValueError("detail must be in ['high', 'low']") + + transfer_methods = config['file_upload']['image']['transfer_methods'] + if not isinstance(transfer_methods, list): + raise ValueError("transfer_methods must be of list type") + for method in transfer_methods: + if method not in ['remote_url', 'local_file']: + raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") + @classmethod def is_external_data_tools_valid(cls, tenant_id: str, config: dict): if 'external_data_tools' not in config or not config["external_data_tools"]: diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 81e35c8593..280bdf7696 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -3,7 +3,7 @@ import logging import threading import time import uuid -from typing import Generator, Union, Any, Optional +from typing import Generator, Union, Any, Optional, List from flask import current_app, Flask from redis.client import PubSub @@ -12,9 +12,11 @@ from sqlalchemy import and_ from core.completion import Completion from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \ ConversationTaskInterruptException +from core.file.message_file_parser import MessageFileParser from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ LLMRateLimitError, \ LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.model_providers.models.entity.message import PromptMessageFile from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message @@ -35,6 +37,9 @@ class CompletionService: # is streaming mode inputs = args['inputs'] query = args['query'] + files = args['files'] if 'files' in args and args['files'] else [] + auto_generate_name = args['auto_generate_name'] \ + if 'auto_generate_name' in 
args else True if app_model.mode != 'completion' and not query: raise ValueError('query is required') @@ -132,6 +137,14 @@ class CompletionService: # clean input by app_model_config form rules inputs = cls.get_cleaned_inputs(inputs, app_model_config) + # parse files + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + app_model_config, + user + ) + generate_task_id = str(uuid.uuid4()) pubsub = redis_client.pubsub() @@ -146,17 +159,20 @@ class CompletionService: 'app_model_config': app_model_config.copy(), 'query': query, 'inputs': inputs, + 'files': file_objs, 'detached_user': user, 'detached_conversation': conversation, 'streaming': streaming, 'is_model_config_override': is_model_config_override, - 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev' + 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev', + 'auto_generate_name': auto_generate_name }) generate_worker_thread.start() # wait for 10 minutes to close the thread - cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) + cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, + generate_task_id) return cls.compact_response(pubsub, streaming) @@ -172,10 +188,12 @@ class CompletionService: return user @classmethod - def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig, - query: str, inputs: dict, detached_user: Union[Account, EndUser], + def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, + app_model_config: AppModelConfig, + query: str, inputs: dict, files: List[PromptMessageFile], + detached_user: Union[Account, EndUser], detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool, - retriever_from: str = 'dev'): + retriever_from: str = 'dev', auto_generate_name: bool = True): with flask_app.app_context(): # fixed the state of the model object when it detached from the original session user = db.session.merge(detached_user) @@ -195,10 +213,12 @@ class CompletionService: query=query, inputs=inputs, user=user, + files=files, conversation=conversation, streaming=streaming, is_override=is_model_config_override, - retriever_from=retriever_from + retriever_from=retriever_from, + auto_generate_name=auto_generate_name ) except (ConversationTaskInterruptException, ConversationTaskStoppedException): pass @@ -215,7 +235,8 @@ class CompletionService: db.session.commit() @classmethod - def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread: + def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, + generate_task_id) -> threading.Thread: # wait for 10 minutes to close the thread timeout = 600 @@ -274,6 +295,12 @@ class CompletionService: model_dict['completion_params'] = completion_params app_model_config.model = json.dumps(model_dict) + # parse files + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_objs = message_file_parser.transform_message_files( + message.files, app_model_config + ) + generate_task_id = str(uuid.uuid4()) pubsub = redis_client.pubsub() @@ -288,11 +315,13 @@ class CompletionService: 'app_model_config': app_model_config.copy(), 'query': message.query, 'inputs': 
message.inputs, + 'files': file_objs, 'detached_user': user, 'detached_conversation': None, 'streaming': streaming, 'is_model_config_override': True, - 'retriever_from': retriever_from + 'retriever_from': retriever_from, + 'auto_generate_name': False }) generate_worker_thread.start() @@ -388,7 +417,8 @@ class CompletionService: if event == 'message': yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" elif event == 'message_replace': - yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n" + yield "data: " + json.dumps( + cls.get_message_replace_response_data(result.get('data'))) + "\n\n" elif event == 'chain': yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" elif event == 'agent_thought': diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 45f18ec4cf..0872a232f0 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,17 +1,20 @@ from typing import Union, Optional +from core.generator.llm_generator import LLMGenerator from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db from models.account import Account -from models.model import Conversation, App, EndUser +from models.model import Conversation, App, EndUser, Message from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError +from services.errors.message import MessageNotExistsError class ConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], last_id: Optional[str], limit: int, - include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + include_ids: Optional[list] = None, exclude_ids: Optional[list] = None, + exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -29,6 +32,9 @@ class ConversationService: if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + if exclude_debug_conversation: + base_query = base_query.filter(Conversation.override_model_configs == None) + if last_id: last_conversation = base_query.filter( Conversation.id == last_id, @@ -63,10 +69,36 @@ class ConversationService: @classmethod def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account | EndUser]], name: str): + user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool): conversation = cls.get_conversation(app_model, conversation_id, user) - conversation.name = name + if auto_generate: + return cls.auto_generate_name(app_model, conversation) + else: + conversation.name = name + db.session.commit() + + return conversation + + @classmethod + def auto_generate_name(cls, app_model: App, conversation: Conversation): + # get conversation first message + message = db.session.query(Message) \ + .filter( + Message.app_id == app_model.id, + Message.conversation_id == conversation.id + ).order_by(Message.created_at.asc()).first() + + if not message: + raise MessageNotExistsError() + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) + conversation.name = name + except: + pass + db.session.commit() return conversation diff --git a/api/services/file_service.py b/api/services/file_service.py index 79e53738e0..aaf82b57c8 
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 79e53738e0..aaf82b57c8 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -1,46 +1,62 @@
 import datetime
 import hashlib
-import time
 import uuid
+from typing import Generator, Tuple, Union

-from cachetools import TTLCache
-from flask import request, current_app
+from flask import current_app
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound

 from core.data_loader.file_extractor import FileExtractor
+from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_storage import storage
 from extensions.ext_database import db
-from models.model import UploadFile
+from models.account import Account
+from models.model import UploadFile, EndUser
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError

-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
+                      'jpg', 'jpeg', 'png', 'webp', 'gif']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
 PREVIEW_WORDS_LIMIT = 3000
-cache = TTLCache(maxsize=None, ttl=30)


 class FileService:

     @staticmethod
-    def upload_file(file: FileStorage) -> UploadFile:
+    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
+        extension = file.filename.split('.')[-1]
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
         # read file content
         file_content = file.read()
+        # get file size
         file_size = len(file_content)

-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+        if extension.lower() in IMAGE_EXTENSIONS:
+            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
+        else:
+            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+
         if file_size > file_size_limit:
             message = f'File size exceeded. {file_size} > {file_size_limit}'
             raise FileTooLargeError(message)

-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
         # user uuid as file name
         file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
+
+        if isinstance(user, Account):
+            current_tenant_id = user.current_tenant_id
+        else:
+            # end_user
+            current_tenant_id = user.tenant_id
+
+        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension

         # save file to storage
         storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
         # save file to db
         config = current_app.config
         upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             storage_type=config['STORAGE_TYPE'],
             key=file_key,
             name=file.filename,
             size=file_size,
             extension=extension,
             mime_type=file.mimetype,
-            created_by=current_user.id,
+            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
+            created_by=user.id,
             created_at=datetime.datetime.utcnow(),
             used=False,
             hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +116,6 @@ class FileService:

     @staticmethod
     def get_file_preview(file_id: str) -> str:
-        # get file storage key
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()
@@ -121,3 +132,25 @@ class FileService:
             text = text[0:PREVIEW_WORDS_LIMIT] if text else ''

         return text
+
+    @staticmethod
+    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
+        result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
+        if not result:
+            raise NotFound("File not found or signature is invalid")
+
+        upload_file = db.session.query(UploadFile) \
+            .filter(UploadFile.id == file_id) \
+            .first()
+
+        if not upload_file:
+            raise NotFound("File not found or signature is invalid")
+
+        # check the file extension is an image
+        extension = upload_file.extension
+        if extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        generator = storage.load(upload_file.key, stream=True)
+
+        return generator, upload_file.mime_type
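`get_image_preview` delegates signature checking to `UploadFileParser.verify_image_file_signature`, which is not part of this diff. A sketch of the usual HMAC scheme behind such signed, expiring URLs — the message layout, key handling, and five-minute lifetime here are assumptions, not Dify's actual implementation:

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = os.urandom(32)   # stand-in for the app's SECRET_KEY
SIGNATURE_TTL = 5 * 60        # assumed expiry window, in seconds

def sign_image_preview(file_id: str) -> dict:
    timestamp = str(int(time.time()))
    nonce = os.urandom(8).hex()
    msg = f'image-preview|{file_id}|{timestamp}|{nonce}'.encode()
    sign = base64.urlsafe_b64encode(
        hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()).decode()
    return {'timestamp': timestamp, 'nonce': nonce, 'sign': sign}

def verify_image_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    msg = f'image-preview|{file_id}|{timestamp}|{nonce}'.encode()
    expected = base64.urlsafe_b64encode(
        hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()).decode()
    # constant-time compare first, then reject anything past the expiry window
    if not hmac.compare_digest(expected, sign):
        return False
    return int(timestamp) + SIGNATURE_TTL >= time.time()

params = sign_image_preview('9d1f-example')   # hypothetical file id
assert verify_image_file_signature('9d1f-example', **params)
```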
diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py
index 231083db19..3c521909dc 100644
--- a/api/services/web_conversation_service.py
+++ b/api/services/web_conversation_service.py
@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
 class WebConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
-                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
+                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         include_ids = None
         exclude_ids = None
         if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
             last_id=last_id,
             limit=limit,
             include_ids=include_ids,
-            exclude_ids=exclude_ids
+            exclude_ids=exclude_ids,
+            exclude_debug_conversation=exclude_debug_conversation
         )

     @classmethod
diff --git a/api/tests/integration_tests/models/llm/test_openai_model.py b/api/tests/integration_tests/models/llm/test_openai_model.py
index e6044c0bb5..e74836e9d3 100644
--- a/api/tests/integration_tests/models/llm/test_openai_model.py
+++ b/api/tests/integration_tests/models/llm/test_openai_model.py
@@ -5,7 +5,7 @@ from unittest.mock import patch
 from langchain.schema import Generation, ChatGeneration, AIMessage

 from core.model_providers.providers.openai_provider import OpenAIProvider
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
 from core.model_providers.models.entity.model_params import ModelKwargs
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
     assert rst == 22


+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_chat_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(
+                data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.get_num_tokens(messages)
+    assert rst == 77
+
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
         messages,
         stop=['\nHuman:'],
     )
+    assert (len(rst.content) > 0)
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.run(
+        messages,
+    )
     assert len(rst.content) > 0
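The vision tests attach an `ImageMessageFile` to a `PromptMessage`; downstream, the model layer has to turn that into OpenAI's multi-part chat content, sending either a base64 data URL or a plain URL depending on the new `MULTIMODAL_SEND_IMAGE_FORMAT` setting. A sketch of that mapping — the content-part layout is OpenAI's public vision format, while the switch and function below are illustrative, not Dify's internal code:

```python
import base64
from typing import Union

def to_openai_vision_message(text: str, image: Union[bytes, str],
                             send_format: str = 'base64') -> dict:
    if send_format == 'base64' and isinstance(image, bytes):
        # inline the image bytes as a data URL
        url = 'data:image/jpeg;base64,' + base64.b64encode(image).decode()
    else:
        # pass a plain (or signed) http(s) URL straight through, as the tests above do
        url = image
    return {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': text},
            {'type': 'image_url', 'image_url': {'url': url}},
        ],
    }

message = to_openai_vision_message(
    'What’s in first image?',
    'https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg',
    send_format='url',
)
```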
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 5031d489fb..7463a17ec9 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -19,18 +19,22 @@ services:
       # different from api or web app domain.
       # example: http://cloud.dify.ai
       CONSOLE_API_URL: ''
-      # The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
+      # The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
       # different from console domain.
       # example: http://api.dify.ai
       SERVICE_API_URL: ''
-      # The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_API_URL: ''
-      # The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_WEB_URL: ''
+      # File preview or download URL prefix,
+      # used to expose file preview or download URLs to the front-end or as multi-modal inputs;
+      # the URL is signed and has an expiration time.
+      FILES_URL: ''
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.
diff --git a/docker/nginx/conf.d/default.conf b/docker/nginx/conf.d/default.conf
index 279de0c328..879ce63164 100644
--- a/docker/nginx/conf.d/default.conf
+++ b/docker/nginx/conf.d/default.conf
@@ -17,6 +17,11 @@ server {
         include proxy.conf;
     }

+    location /files {
+        proxy_pass http://api:5001;
+        include proxy.conf;
+    }
+
     location / {
         proxy_pass http://web:3000;
         include proxy.conf;
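The new nginx `location /files` forwards to the API container, where the `controllers/files` blueprint handles the request. That blueprint is not shown in this diff; a minimal sketch of the image-preview route it plausibly exposes (the route path, query-parameter names, and `FileService` call shape are assumptions based on `FileService.get_image_preview` above):

```python
from flask import Blueprint, Response, request

from services.file_service import FileService

bp = Blueprint('files', __name__, url_prefix='/files')

@bp.get('/<uuid:file_id>/image-preview')
def image_preview(file_id):
    # signature parameters come from the signed URL built with FILES_URL as its prefix
    generator, mimetype = FileService.get_image_preview(
        file_id=str(file_id),
        timestamp=request.args.get('timestamp'),
        nonce=request.args.get('nonce'),
        sign=request.args.get('sign'),
    )
    # stream the file back instead of buffering it, matching storage.load(stream=True)
    return Response(generator, mimetype=mimetype)
```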