From 3571292fbf7790a5261d7f5f3760c2264a6bc170 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 15 Aug 2024 12:54:05 +0800 Subject: [PATCH] chore(api): Introduce Ruff Formatter. (#7291) --- .github/workflows/style.yml | 4 + api/app.py | 161 +++---- api/commands.py | 415 +++++++++--------- api/constants/__init__.py | 2 +- api/constants/languages.py | 39 +- api/constants/model_template.py | 99 ++--- api/contexts/__init__.py | 4 +- api/events/app_event.py | 8 +- api/events/dataset_event.py | 2 +- api/events/document_event.py | 2 +- .../clean_when_dataset_deleted.py | 10 +- .../clean_when_document_deleted.py | 6 +- .../event_handlers/create_document_index.py | 24 +- .../create_installed_app_when_app_created.py | 2 +- .../create_site_record_when_app_created.py | 10 +- .../deduct_quota_when_messaeg_created.py | 8 +- ...rameters_cache_when_sync_draft_workflow.py | 6 +- .../event_handlers/document_index_event.py | 2 +- ...aset_join_when_app_model_config_updated.py | 26 +- ...oin_when_app_published_workflow_updated.py | 23 +- ...vider_last_used_at_when_messaeg_created.py | 6 +- api/events/message_event.py | 2 +- api/events/tenant_event.py | 4 +- api/extensions/ext_celery.py | 23 +- api/extensions/ext_compress.py | 11 +- api/extensions/ext_database.py | 10 +- api/extensions/ext_mail.py | 72 +-- api/extensions/ext_redis.py | 27 +- api/extensions/ext_sentry.py | 17 +- api/extensions/ext_storage.py | 38 +- api/extensions/storage/aliyun_storage.py | 13 +- api/extensions/storage/azure_storage.py | 18 +- api/extensions/storage/base_storage.py | 5 +- api/extensions/storage/google_storage.py | 15 +- api/extensions/storage/local_storage.py | 29 +- api/extensions/storage/oci_storage.py | 25 +- api/extensions/storage/s3_storage.py | 33 +- api/extensions/storage/tencent_storage.py | 19 +- api/fields/annotation_fields.py | 6 +- api/fields/api_based_extension_fields.py | 14 +- api/fields/app_fields.py | 226 +++++----- api/fields/conversation_fields.py | 275 ++++++------ api/fields/conversation_variable_fields.py | 24 +- api/fields/data_source_fields.py | 72 ++- api/fields/dataset_fields.py | 99 ++--- api/fields/document_fields.py | 120 +++-- api/fields/end_user_fields.py | 8 +- api/fields/file_fields.py | 22 +- api/fields/hit_testing_fields.py | 60 +-- api/fields/installed_app_fields.py | 28 +- api/fields/member_fields.py | 50 +-- api/fields/message_fields.py | 124 +++--- api/fields/segment_fields.py | 48 +- api/fields/tag_fields.py | 7 +- api/fields/workflow_app_log_fields.py | 18 +- api/fields/workflow_fields.py | 52 +-- api/fields/workflow_run_fields.py | 50 +-- api/pyproject.toml | 13 +- api/schedule/clean_embedding_cache_task.py | 21 +- api/schedule/clean_unused_datasets_task.py | 90 ++-- dev/reformat | 3 + 61 files changed, 1315 insertions(+), 1335 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index f6092c8633..d681dc6627 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -45,6 +45,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example + - name: Ruff formatter check + if: steps.changed-files.outputs.any_changed == 'true' + run: poetry run -C api ruff format --check ./api + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/api/app.py b/api/app.py index 50441cb81d..ad219ca0d6 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ import os -if os.environ.get("DEBUG", "false").lower() != 'true': +if os.environ.get("DEBUG", "false").lower() != "true": from gevent import monkey monkey.patch_all() @@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning) if os.name == "nt": os.system('tzutil /s "UTC"') else: - os.environ['TZ'] = 'UTC' + os.environ["TZ"] = "UTC" time.tzset() @@ -70,13 +70,14 @@ class DifyApp(Flask): # ------------- -config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first +config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first # ---------------------------- # Application Factory Function # ---------------------------- + def create_flask_app_with_configs() -> Flask: """ create a raw flask app @@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask: elif isinstance(value, int | float | bool): os.environ[key] = str(value) elif value is None: - os.environ[key] = '' + os.environ[key] = "" return dify_app @@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask: def create_app() -> Flask: app = create_flask_app_with_configs() - app.secret_key = app.config['SECRET_KEY'] + app.secret_key = app.config["SECRET_KEY"] log_handlers = None - log_file = app.config.get('LOG_FILE') + log_file = app.config.get("LOG_FILE") if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) @@ -111,23 +112,24 @@ def create_app() -> Flask: RotatingFileHandler( filename=log_file, maxBytes=1024 * 1024 * 1024, - backupCount=5 + backupCount=5, ), - logging.StreamHandler(sys.stdout) + logging.StreamHandler(sys.stdout), ] logging.basicConfig( - level=app.config.get('LOG_LEVEL'), - format=app.config.get('LOG_FORMAT'), - datefmt=app.config.get('LOG_DATEFORMAT'), + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), handlers=log_handlers, - force=True + force=True, ) - log_tz = app.config.get('LOG_TZ') + log_tz = app.config.get("LOG_TZ") if log_tz: from datetime import datetime import pytz + timezone = pytz.timezone(log_tz) def time_converter(seconds): @@ -162,24 +164,24 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ['console', 'inner_api']: + if request.blueprint not in ["console", "inner_api"]: return None # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get('Authorization', '') + auth_header = request.headers.get("Authorization", "") if not auth_header: - auth_token = request.args.get('_token') + auth_token = request.args.get("_token") if not auth_token: - raise Unauthorized('Invalid Authorization token.') + raise Unauthorized("Invalid Authorization token.") else: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) - user_id = decoded.get('user_id') + user_id = decoded.get("user_id") account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) if account: @@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login): @login_manager.unauthorized_handler def unauthorized_handler(): """Handle unauthorized requests.""" - return Response(json.dumps({ - 'code': 'unauthorized', - 'message': "Unauthorized." - }), status=401, content_type="application/json") + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) # register blueprint routers @@ -204,38 +207,36 @@ def register_blueprints(app): from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp - CORS(service_api_bp, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(service_api_bp) - CORS(web_bp, - resources={ - r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(web_bp) - CORS(console_app_bp, - resources={ - r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(console_app_bp) - CORS(files_bp, - allow_headers=['Content-Type'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) @@ -245,29 +246,29 @@ def register_blueprints(app): app = create_app() celery = app.extensions["celery"] -if app.config.get('TESTING'): +if app.config.get("TESTING"): print("App is running in TESTING mode") @app.after_request def after_request(response): """Add Version headers to the response.""" - response.set_cookie('remember_token', '', expires=0) - response.headers.add('X-Version', app.config['CURRENT_VERSION']) - response.headers.add('X-Env', app.config['DEPLOY_ENV']) + response.set_cookie("remember_token", "", expires=0) + response.headers.add("X-Version", app.config["CURRENT_VERSION"]) + response.headers.add("X-Env", app.config["DEPLOY_ENV"]) return response -@app.route('/health') +@app.route("/health") def health(): - return Response(json.dumps({ - 'pid': os.getpid(), - 'status': 'ok', - 'version': app.config['CURRENT_VERSION'] - }), status=200, content_type="application/json") + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + status=200, + content_type="application/json", + ) -@app.route('/threads') +@app.route("/threads") def threads(): num_threads = threading.active_count() threads = threading.enumerate() @@ -278,32 +279,34 @@ def threads(): thread_id = thread.ident is_alive = thread.is_alive() - thread_list.append({ - 'name': thread_name, - 'id': thread_id, - 'is_alive': is_alive - }) + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) return { - 'pid': os.getpid(), - 'thread_num': num_threads, - 'threads': thread_list + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, } -@app.route('/db-pool-stat') +@app.route("/db-pool-stat") def pool_stat(): engine = db.engine return { - 'pid': os.getpid(), - 'pool_size': engine.pool.size(), - 'checked_in_connections': engine.pool.checkedin(), - 'checked_out_connections': engine.pool.checkedout(), - 'overflow_connections': engine.pool.overflow(), - 'connection_timeout': engine.pool.timeout(), - 'recycle_time': db.engine.pool._recycle + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, } -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5001) +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5001) diff --git a/api/commands.py b/api/commands.py index 82a32f0f5b..41f1a6444c 100644 --- a/api/commands.py +++ b/api/commands.py @@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel from services.account_service import RegisterService, TenantService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset") +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) return # generate password salt @@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations! Password has been reset.', fg='green')) + click.echo(click.style("Congratulations! Password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo( - click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) -@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. ' - 'After the reset, all LLM credentials will become invalid, ' - 'requiring re-entry.' - 'Only support SELF_HOSTED mode.') -@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' - ' this operation cannot be rolled back!', fg='red')) +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + ) +) def reset_encrypt_key_pair(): """ Reset the encrypted key pair of workspace for encrypt LLM credentials. After the reset, all LLM credentials will become invalid, requiring re-entry. Only support SELF_HOSTED mode. """ - if dify_config.EDITION != 'SELF_HOSTED': - click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) + if dify_config.EDITION != "SELF_HOSTED": + click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) + click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() - click.echo(click.style('Congratulations! ' - 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) + click.echo( + click.style( + "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + fg="green", + ) + ) -@click.command('vdb-migrate', help='migrate vector db.') -@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') +@click.command("vdb-migrate", help="migrate vector db.") +@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ['knowledge', 'all']: + if scope in ["knowledge", "all"]: migrate_knowledge_vector_database() - if scope in ['annotation', 'all']: + if scope in ["annotation", "all"]: migrate_annotation_vector_database() @@ -146,7 +150,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style('Start migrate annotation data.', fg='green')) + click.echo(click.style("Start migrate annotation data.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -154,98 +158,103 @@ def migrate_annotation_vector_database(): while True: try: # get apps info - apps = db.session.query(App).filter( - App.status == 'normal' - ).order_by(App.created_at.desc()).paginate(page=page, per_page=50) + apps = ( + db.session.query(App) + .filter(App.status == "normal") + .order_by(App.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for app in apps: total_count = total_count + 1 - click.echo(f'Processing the {total_count} app {app.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create app annotation index: {}'.format(app.id)) - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app.id - ).first() + click.echo("Create app annotation index: {}".format(app.id)) + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo('App annotation setting is disabled: {}'.format(app.id)) + click.echo("App annotation setting is disabled: {}".format(app.id)) continue # get dataset_collection_binding info - dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( - DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id - ).first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) if not dataset_collection_binding: - click.echo('App annotation collection binding is not exist: {}'.format(app.id)) + click.echo("App annotation collection binding is not exist: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) documents = [] if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app.id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Start to migrate annotation, app_id: {app.id}.") try: vector.delete() - click.echo( - click.style(f'Successfully delete vector index for app: {app.id}.', - fg='green')) + click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green")) except Exception as e: - click.echo( - click.style(f'Failed to delete vector index for app {app.id}.', - fg='red')) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} annotations for app {app.id}.', - fg='green')) - vector.create(documents) click.echo( - click.style(f'Successfully created vector index for app {app.id}.', fg='green')) + click.style( + f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) + vector.create(documents) + click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green")) except Exception as e: - click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) raise e - click.echo(f'Successfully migrated app annotation {app.id}.') + click.echo(f"Successfully migrated app annotation {app.id}.") create_count += 1 except Exception as e: click.echo( - click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style( + "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red" + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + fg="green", + ) + ) def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style('Start migrate vector db.', fg='green')) + click.echo(click.style("Start migrate vector db.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -253,87 +262,77 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + datasets = ( + db.session.query(Dataset) + .filter(Dataset.indexing_technique == "high_quality") + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for dataset in datasets: total_count = total_count + 1 - click.echo(f'Processing the {total_count} dataset {dataset.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} dataset {dataset.id}. " + + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create dataset vdb index: {}'.format(dataset.id)) + click.echo("Create dataset vdb index: {}".format(dataset.id)) if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] == vector_type: + if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue - collection_name = '' + collection_name = "" if vector_type == VectorType.WEAVIATE: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.WEAVIATE, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.QDRANT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.MILVUS: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.MILVUS, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.RELYT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'relyt', - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.TENCENT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.TENCENT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.PGVECTOR: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.PGVECTOR, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.OPENSEARCH: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ANALYTICDB: @@ -341,16 +340,13 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ELASTICSEARCH: dataset_id = dataset.id index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'elasticsearch', - "vector_store": {"class_prefix": index_name} - } + index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -361,29 +357,41 @@ def migrate_knowledge_vector_database(): try: vector.delete() click.echo( - click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', - fg='green')) + click.style( + f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green" + ) + ) except Exception as e: click.echo( - click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', - fg='red')) + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) raise e - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .all() + ) for segment in segments: document = Document( @@ -393,7 +401,7 @@ def migrate_knowledge_vector_database(): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -401,37 +409,43 @@ def migrate_knowledge_vector_database(): if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + fg="green", + ) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) + click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green") + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e db.session.add(dataset) db.session.commit() - click.echo(f'Successfully migrated dataset {dataset.id}.') + click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 except Exception as e: db.session.rollback() click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green" + ) + ) -@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style('Start convert to agent apps.', fg='green')) + click.echo(click.style("Start convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -466,7 +480,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo('Converting app: {}'.format(app.id)) + click.echo("Converting app: {}".format(app.id)) try: app.mode = AppMode.AGENT_CHAT.value @@ -478,137 +492,139 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + click.echo(click.style("Converted app: {}".format(app.id), fg="green")) except Exception as e: - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) + click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) - click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) -@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') -@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') +@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.") def add_qdrant_doc_id_index(field: str): - click.echo(click.style('Start add qdrant doc_id index.', fg='green')) + click.echo(click.style("Start add qdrant doc_id index.", fg="green")) vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": - click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) + click.echo(click.style("Sorry, only support qdrant vector store.", fg="red")) return create_count = 0 try: bindings = db.session.query(DatasetCollectionBinding).all() if not bindings: - click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) + click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red")) return import qdrant_client from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + for binding in bindings: if dify_config.QDRANT_URL is None: - raise ValueError('Qdrant url is required.') + raise ValueError("Qdrant url is required.") qdrant_config = QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, root_path=current_app.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) # create payload index - client.create_payload_index(binding.collection_name, field, - field_schema=PayloadSchemaType.KEYWORD) + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 except UnexpectedResponse as e: # Collection does not exist, so return if e.status_code == 404: - click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red") + ) continue # Some other error occurred, so re-raise the exception else: - click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style( + f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red" + ) + ) except Exception as e: - click.echo(click.style('Failed to create qdrant client.', fg='red')) + click.echo(click.style("Failed to create qdrant client.", fg="red")) - click.echo( - click.style(f'Congratulations! Create {create_count} collection indexes.', - fg='green')) + click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green")) -@click.command('create-tenant', help='Create account and tenant.') -@click.option('--email', prompt=True, help='The email address of the tenant account.') -@click.option('--language', prompt=True, help='Account language, default: en-US.') +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") def create_tenant(email: str, language: Optional[str] = None): """ Create tenant account """ if not email: - click.echo(click.style('Sorry, email is required.', fg='red')) + click.echo(click.style("Sorry, email is required.", fg="red")) return # Create account email = email.strip() - if '@' not in email: - click.echo(click.style('Sorry, invalid email address.', fg='red')) + if "@" not in email: + click.echo(click.style("Sorry, invalid email address.", fg="red")) return - account_name = email.split('@')[0] + account_name = email.split("@")[0] if language not in languages: - language = 'en-US' + language = "en-US" # generate random password new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language - ) + account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) TenantService.create_owner_tenant_if_not_exist(account) - click.echo(click.style('Congratulations! Account and tenant created.\n' - 'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) + click.echo( + click.style( + "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + fg="green", + ) + ) -@click.command('upgrade-db', help='upgrade the database') +@click.command("upgrade-db", help="upgrade the database") def upgrade_db(): - click.echo('Preparing database migration...') - lock = redis_client.lock(name='db_upgrade_lock', timeout=60) + click.echo("Preparing database migration...") + lock = redis_client.lock(name="db_upgrade_lock", timeout=60) if lock.acquire(blocking=False): try: - click.echo(click.style('Start database migration.', fg='green')) + click.echo(click.style("Start database migration.", fg="green")) # run db migration import flask_migrate + flask_migrate.upgrade() - click.echo(click.style('Database migration successful!', fg='green')) + click.echo(click.style("Database migration successful!", fg="green")) except Exception as e: - logging.exception(f'Database migration failed, error: {e}') + logging.exception(f"Database migration failed, error: {e}") finally: lock.release() else: - click.echo('Database migration skipped') + click.echo("Database migration skipped") -@click.command('fix-app-site-missing', help='Fix app related site missing issue.') +@click.command("fix-app-site-missing", help="Fix app related site missing issue.") def fix_app_site_missing(): """ Fix app related site missing issue. """ - click.echo(click.style('Start fix app related site missing issue.', fg='green')) + click.echo(click.style("Start fix app related site missing issue.", fg="green")) failed_app_ids = [] while True: @@ -639,15 +655,14 @@ where sites.id is null limit 1000""" app_was_created.send(app, account=account) except Exception as e: failed_app_ids.append(app_id) - click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red')) - logging.exception(f'Fix app related site missing issue failed, error: {e}') + click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red")) + logging.exception(f"Fix app related site missing issue failed, error: {e}") continue if not processed_count: break - - click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green')) + click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green")) def register_commands(app): diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e374c04316..e22c3268ef 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1 +1 @@ -HIDDEN_VALUE = '[__HIDDEN__]' +HIDDEN_VALUE = "[__HIDDEN__]" diff --git a/api/constants/languages.py b/api/constants/languages.py index 38e49e0d1e..524dc61b57 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,22 +1,22 @@ language_timezone_mapping = { - 'en-US': 'America/New_York', - 'zh-Hans': 'Asia/Shanghai', - 'zh-Hant': 'Asia/Taipei', - 'pt-BR': 'America/Sao_Paulo', - 'es-ES': 'Europe/Madrid', - 'fr-FR': 'Europe/Paris', - 'de-DE': 'Europe/Berlin', - 'ja-JP': 'Asia/Tokyo', - 'ko-KR': 'Asia/Seoul', - 'ru-RU': 'Europe/Moscow', - 'it-IT': 'Europe/Rome', - 'uk-UA': 'Europe/Kyiv', - 'vi-VN': 'Asia/Ho_Chi_Minh', - 'ro-RO': 'Europe/Bucharest', - 'pl-PL': 'Europe/Warsaw', - 'hi-IN': 'Asia/Kolkata', - 'tr-TR': 'Europe/Istanbul', - 'fa-IR': 'Asia/Tehran', + "en-US": "America/New_York", + "zh-Hans": "Asia/Shanghai", + "zh-Hant": "Asia/Taipei", + "pt-BR": "America/Sao_Paulo", + "es-ES": "Europe/Madrid", + "fr-FR": "Europe/Paris", + "de-DE": "Europe/Berlin", + "ja-JP": "Asia/Tokyo", + "ko-KR": "Asia/Seoul", + "ru-RU": "Europe/Moscow", + "it-IT": "Europe/Rome", + "uk-UA": "Europe/Kyiv", + "vi-VN": "Asia/Ho_Chi_Minh", + "ro-RO": "Europe/Bucharest", + "pl-PL": "Europe/Warsaw", + "hi-IN": "Asia/Kolkata", + "tr-TR": "Europe/Istanbul", + "fa-IR": "Asia/Tehran", } languages = list(language_timezone_mapping.keys()) @@ -26,6 +26,5 @@ def supported_language(lang): if lang in languages: return lang - error = ('{lang} is not a valid language.' - .format(lang=lang)) + error = "{lang} is not a valid language.".format(lang=lang) raise ValueError(error) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index cc5a370254..7e1a196356 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -5,82 +5,79 @@ from models.model import AppMode default_app_templates = { # workflow default mode AppMode.WORKFLOW: { - 'app': { - 'mode': AppMode.WORKFLOW.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.WORKFLOW.value, + "enable_site": True, + "enable_api": True, } }, - # completion default mode AppMode.COMPLETION: { - 'app': { - 'mode': AppMode.COMPLETION.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.COMPLETION.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} + "completion_params": {}, }, - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + "user_input_form": json.dumps( + [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "", + }, + }, + ] + ), + "pre_prompt": "{{query}}", }, - }, - # chat default mode AppMode.CHAT: { - 'app': { - 'mode': AppMode.CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } + "completion_params": {}, + }, + }, }, - # advanced-chat default mode AppMode.ADVANCED_CHAT: { - 'app': { - 'mode': AppMode.ADVANCED_CHAT.value, - 'enable_site': True, - 'enable_api': True - } + "app": { + "mode": AppMode.ADVANCED_CHAT.value, + "enable_site": True, + "enable_api": True, + }, }, - # agent-chat default mode AppMode.AGENT_CHAT: { - 'app': { - 'mode': AppMode.AGENT_CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.AGENT_CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } - } + "completion_params": {}, + }, + }, + }, } diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index b6b18f5c5b..623a1a28eb 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -2,6 +2,6 @@ from contextvars import ContextVar from core.workflow.entities.variable_pool import VariablePool -tenant_id: ContextVar[str] = ContextVar('tenant_id') +tenant_id: ContextVar[str] = ContextVar("tenant_id") -workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') +workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") diff --git a/api/events/app_event.py b/api/events/app_event.py index 67a5982527..f2ce71bbbb 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -1,13 +1,13 @@ from blinker import signal # sender: app -app_was_created = signal('app-was-created') +app_was_created = signal("app-was-created") # sender: app, kwargs: app_model_config -app_model_config_was_updated = signal('app-model-config-was-updated') +app_model_config_was_updated = signal("app-model-config-was-updated") # sender: app, kwargs: published_workflow -app_published_workflow_was_updated = signal('app-published-workflow-was-updated') +app_published_workflow_was_updated = signal("app-published-workflow-was-updated") # sender: app, kwargs: synced_draft_workflow -app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') +app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py index d4a2b6f313..750b7424e2 100644 --- a/api/events/dataset_event.py +++ b/api/events/dataset_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: dataset -dataset_was_deleted = signal('dataset-was-deleted') +dataset_was_deleted = signal("dataset-was-deleted") diff --git a/api/events/document_event.py b/api/events/document_event.py index f95326630b..2c5a416a5e 100644 --- a/api/events/document_event.py +++ b/api/events/document_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_was_deleted = signal('document-was-deleted') +document_was_deleted = signal("document-was-deleted") diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 42f1c70614..7caa2d1cc9 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect def handle(sender, **kwargs): dataset = sender - clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) + clean_dataset_task.delay( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + ) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 24022da15f..00a66f50ad 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task @document_was_deleted.connect def handle(sender, **kwargs): document_id = sender - dataset_id = kwargs.get('dataset_id') - doc_form = kwargs.get('doc_form') - file_id = kwargs.get('file_id') + dataset_id = kwargs.get("dataset_id") + doc_form = kwargs.get("doc_form") + file_id = kwargs.get("file_id") clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 68dae5a553..72a135e73d 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,21 +14,25 @@ from models.dataset import Document @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get('document_ids', None) + document_ids = kwargs.get("document_ids", None) documents = [] start_at = time.perf_counter() for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document) + .filter( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) + .first() + ) if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -38,8 +42,8 @@ def handle(sender, **kwargs): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 31084ce0fe..57412cc4ad 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -10,7 +10,7 @@ def handle(sender, **kwargs): installed_app = InstalledApp( tenant_id=app.tenant_id, app_id=app.id, - app_owner_tenant_id=app.tenant_id + app_owner_tenant_id=app.tenant_id, ) db.session.add(installed_app) db.session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index f0eb7159b6..abaf0e41ec 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -7,15 +7,15 @@ from models.model import Site def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender - account = kwargs.get('account') + account = kwargs.get("account") site = Site( app_id=app.id, title=app.name, - icon = app.icon, - icon_background = app.icon_background, + icon=app.icon, + icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) + customize_token_strategy="not_allow", + code=Site.generate_code(16), ) db.session.add(site) diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8cf52bf8f5..843a232096 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return @@ -39,7 +39,7 @@ def handle(sender, **kwargs): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_config.model: + if "gpt-4" in model_config.model: used_quota = 20 else: used_quota = 1 @@ -50,6 +50,6 @@ def handle(sender, **kwargs): Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1f6da34ee2..f96bb5ef74 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): - if node_data.get('data', {}).get('type') == NodeType.TOOL.value: + for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( @@ -23,7 +23,7 @@ def handle(sender, **kwargs): tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}', ) manager.delete_tool_parameters_cache() except: diff --git a/api/events/event_handlers/document_index_event.py b/api/events/event_handlers/document_index_event.py index 9c4e055deb..3d463fe5b3 100644 --- a/api/events/event_handlers/document_index_event.py +++ b/api/events/event_handlers/document_index_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_index_created = signal('document-index-created') +document_index_created = signal("document-index-created") diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 2b202c53d0..59375b1a0b 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -7,13 +7,11 @@ from models.model import AppModelConfig @app_model_config_was_updated.connect def handle(sender, **kwargs): app = sender - app_model_config = kwargs.get('app_model_config') + app_model_config = kwargs.get("app_model_config") dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -29,16 +27,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) or [] + tools = agent_mode.get("tools", []) or [] for tool in tools: if len(list(tool.keys())) != 1: continue @@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict - datasets = dataset_configs.get('datasets', {}) or {} - for dataset in datasets.get('datasets', []) or []: + datasets = dataset_configs.get("datasets", {}) or {} + for dataset in datasets.get("datasets", []) or []: keys = list(dataset.keys()) - if len(keys) == 1 and keys[0] == 'dataset': - if dataset['dataset'].get('id'): - dataset_ids.add(dataset['dataset'].get('id')) + if len(keys) == 1 and keys[0] == "dataset": + if dataset["dataset"].get("id"): + dataset_ids.add(dataset["dataset"].get("id")) return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 996b1e9691..333b85ecb2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -11,13 +11,11 @@ from models.workflow import Workflow @app_published_workflow_was_updated.connect def handle(sender, **kwargs): app = sender - published_workflow = kwargs.get('published_workflow') + published_workflow = kwargs.get("published_workflow") published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -33,16 +31,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: if not graph: return dataset_ids - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) # fetch all knowledge retrieval nodes - knowledge_retrieval_nodes = [node for node in nodes - if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value] + knowledge_retrieval_nodes = [ + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + ] if not knowledge_retrieval_nodes: return dataset_ids for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get('data', {})) + node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) dataset_ids.update(node_data.dataset_ids) except Exception as e: continue diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 6188f1a085..a80572c0de 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -9,13 +9,13 @@ from models.provider import Provider @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, - Provider.provider_name == application_generate_entity.model_conf.provider - ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) + Provider.provider_name == application_generate_entity.model_conf.provider, + ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) db.session.commit() diff --git a/api/events/message_event.py b/api/events/message_event.py index 21da83f249..6576c35c45 100644 --- a/api/events/message_event.py +++ b/api/events/message_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: message, kwargs: conversation -message_was_created = signal('message-was-created') +message_was_created = signal("message-was-created") diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py index 942f709917..d99feaac40 100644 --- a/api/events/tenant_event.py +++ b/api/events/tenant_event.py @@ -1,7 +1,7 @@ from blinker import signal # sender: tenant -tenant_was_created = signal('tenant-was-created') +tenant_was_created = signal("tenant-was-created") # sender: tenant -tenant_was_updated = signal('tenant-was-updated') +tenant_was_updated = signal("tenant-was-updated") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index ae9a075340..f5ec7c1759 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -17,7 +17,7 @@ def init_app(app: Flask) -> Celery: backend=app.config["CELERY_BACKEND"], task_ignore_result=True, ) - + # Add SSL options to the Celery configuration ssl_options = { "ssl_cert_reqs": None, @@ -35,7 +35,7 @@ def init_app(app: Flask) -> Celery: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) - + celery_app.set_default() app.extensions["celery"] = celery_app @@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery: ] day = app.config["CELERY_BEAT_SCHEDULER_TIME"] beat_schedule = { - 'clean_embedding_cache_task': { - 'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', - 'schedule': timedelta(days=day), + "clean_embedding_cache_task": { + "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", + "schedule": timedelta(days=day), + }, + "clean_unused_datasets_task": { + "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", + "schedule": timedelta(days=day), }, - 'clean_unused_datasets_task': { - 'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task', - 'schedule': timedelta(days=day), - } } - celery_app.conf.update( - beat_schedule=beat_schedule, - imports=imports - ) + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 1dbaffcfb0..38e67749fc 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -2,15 +2,14 @@ from flask import Flask def init_app(app: Flask): - if app.config.get('API_COMPRESSION_ENABLED'): + if app.config.get("API_COMPRESSION_ENABLED"): from flask_compress import Compress - app.config['COMPRESS_MIMETYPES'] = [ - 'application/json', - 'image/svg+xml', - 'text/html', + app.config["COMPRESS_MIMETYPES"] = [ + "application/json", + "image/svg+xml", + "text/html", ] compress = Compress() compress.init_app(app) - diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c248e173a2..f6ffa53634 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy import MetaData POSTGRES_INDEXES_NAMING_CONVENTION = { - 'ix': '%(column_0_label)s_idx', - 'uq': '%(table_name)s_%(column_0_name)s_key', - 'ck': '%(table_name)s_%(constraint_name)s_check', - 'fk': '%(table_name)s_%(column_0_name)s_fkey', - 'pk': '%(table_name)s_pkey', + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index ec3a5cc112..b435294abc 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -14,67 +14,69 @@ class Mail: return self._client is not None def init_app(self, app: Flask): - if app.config.get('MAIL_TYPE'): - if app.config.get('MAIL_DEFAULT_SEND_FROM'): - self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') - - if app.config.get('MAIL_TYPE') == 'resend': - api_key = app.config.get('RESEND_API_KEY') - if not api_key: - raise ValueError('RESEND_API_KEY is not set') + if app.config.get("MAIL_TYPE"): + if app.config.get("MAIL_DEFAULT_SEND_FROM"): + self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") - api_url = app.config.get('RESEND_API_URL') + if app.config.get("MAIL_TYPE") == "resend": + api_key = app.config.get("RESEND_API_KEY") + if not api_key: + raise ValueError("RESEND_API_KEY is not set") + + api_url = app.config.get("RESEND_API_URL") if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get('MAIL_TYPE') == 'smtp': + elif app.config.get("MAIL_TYPE") == "smtp": from libs.smtp import SMTPClient - if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): - raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') - if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): - raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') + + if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") + if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get('SMTP_SERVER'), - port=app.config.get('SMTP_PORT'), - username=app.config.get('SMTP_USERNAME'), - password=app.config.get('SMTP_PASSWORD'), - _from=app.config.get('MAIL_DEFAULT_SEND_FROM'), - use_tls=app.config.get('SMTP_USE_TLS'), - opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') + server=app.config.get("SMTP_SERVER"), + port=app.config.get("SMTP_PORT"), + username=app.config.get("SMTP_USERNAME"), + password=app.config.get("SMTP_PASSWORD"), + _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), + use_tls=app.config.get("SMTP_USE_TLS"), + opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), ) else: - raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) + raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) else: - logging.warning('MAIL_TYPE is not set') - + logging.warning("MAIL_TYPE is not set") def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: - raise ValueError('Mail client is not initialized') + raise ValueError("Mail client is not initialized") if not from_ and self._default_send_from: from_ = self._default_send_from if not from_: - raise ValueError('mail from is not set') + raise ValueError("mail from is not set") if not to: - raise ValueError('mail to is not set') + raise ValueError("mail to is not set") if not subject: - raise ValueError('mail subject is not set') + raise ValueError("mail subject is not set") if not html: - raise ValueError('mail html is not set') + raise ValueError("mail html is not set") - self._client.send({ - "from": from_, - "to": to, - "subject": subject, - "html": html - }) + self._client.send( + { + "from": from_, + "to": to, + "subject": subject, + "html": html, + } + ) def init_app(app: Flask): diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 23d7768d4d..d5fb162fd8 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -6,18 +6,21 @@ redis_client = redis.Redis() def init_app(app): connection_class = Connection - if app.config.get('REDIS_USE_SSL'): + if app.config.get("REDIS_USE_SSL"): connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool(**{ - 'host': app.config.get('REDIS_HOST'), - 'port': app.config.get('REDIS_PORT'), - 'username': app.config.get('REDIS_USERNAME'), - 'password': app.config.get('REDIS_PASSWORD'), - 'db': app.config.get('REDIS_DB'), - 'encoding': 'utf-8', - 'encoding_errors': 'strict', - 'decode_responses': False - }, connection_class=connection_class) + redis_client.connection_pool = redis.ConnectionPool( + **{ + "host": app.config.get("REDIS_HOST"), + "port": app.config.get("REDIS_PORT"), + "username": app.config.get("REDIS_USERNAME"), + "password": app.config.get("REDIS_PASSWORD"), + "db": app.config.get("REDIS_DB"), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + }, + connection_class=connection_class, + ) - app.extensions['redis'] = redis_client + app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index f05c10bc08..227c6635f0 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException def init_app(app): - if app.config.get('SENTRY_DSN'): + if app.config.get("SENTRY_DSN"): sentry_sdk.init( - dsn=app.config.get('SENTRY_DSN'), - integrations=[ - FlaskIntegration(), - CeleryIntegration() - ], + dsn=app.config.get("SENTRY_DSN"), + integrations=[FlaskIntegration(), CeleryIntegration()], ignore_errors=[HTTPException, ValueError], - traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), - profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), - environment=app.config.get('DEPLOY_ENV'), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" + traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), + profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), + environment=app.config.get("DEPLOY_ENV"), + release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 38db1c6ce1..e6c4352577 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -17,31 +17,19 @@ class Storage: self.storage_runner = None def init_app(self, app: Flask): - storage_type = app.config.get('STORAGE_TYPE') - if storage_type == 's3': - self.storage_runner = S3Storage( - app=app - ) - elif storage_type == 'azure-blob': - self.storage_runner = AzureStorage( - app=app - ) - elif storage_type == 'aliyun-oss': - self.storage_runner = AliyunStorage( - app=app - ) - elif storage_type == 'google-storage': - self.storage_runner = GoogleStorage( - app=app - ) - elif storage_type == 'tencent-cos': - self.storage_runner = TencentStorage( - app=app - ) - elif storage_type == 'oci-storage': - self.storage_runner = OCIStorage( - app=app - ) + storage_type = app.config.get("STORAGE_TYPE") + if storage_type == "s3": + self.storage_runner = S3Storage(app=app) + elif storage_type == "azure-blob": + self.storage_runner = AzureStorage(app=app) + elif storage_type == "aliyun-oss": + self.storage_runner = AliyunStorage(app=app) + elif storage_type == "google-storage": + self.storage_runner = GoogleStorage(app=app) + elif storage_type == "tencent-cos": + self.storage_runner = TencentStorage(app=app) + elif storage_type == "oci-storage": + self.storage_runner = OCIStorage(app=app) else: self.storage_runner = LocalStorage(app=app) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b81a8691f1..b962cedc55 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -8,23 +8,22 @@ from extensions.storage.base_storage import BaseStorage class AliyunStorage(BaseStorage): - """Implementation for aliyun storage. - """ + """Implementation for aliyun storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME') + self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") oss_auth_method = aliyun_s3.Auth region = None - if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4': + if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get('ALIYUN_OSS_REGION') - oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')) + region = app_config.get("ALIYUN_OSS_REGION") + oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) self.client = aliyun_s3.Bucket( oss_auth, - app_config.get('ALIYUN_OSS_ENDPOINT'), + app_config.get("ALIYUN_OSS_ENDPOINT"), self.bucket_name, connect_timeout=30, region=region, diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py index af3e7ef849..ca8cbb9188 100644 --- a/api/extensions/storage/azure_storage.py +++ b/api/extensions/storage/azure_storage.py @@ -9,16 +9,15 @@ from extensions.storage.base_storage import BaseStorage class AzureStorage(BaseStorage): - """Implementation for azure storage. - """ + """Implementation for azure storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') - self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') - self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') - self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') + self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") + self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") + self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") + self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") def save(self, filename, data): client = self._sync_client() @@ -39,6 +38,7 @@ class AzureStorage(BaseStorage): blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() + return generate(filename) def download(self, filename, target_filepath): @@ -62,17 +62,17 @@ class AzureStorage(BaseStorage): blob_container.delete_blob(filename) def _sync_client(self): - cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) cache_result = redis_client.get(cache_key) if cache_result is not None: - sas_token = cache_result.decode('utf-8') + sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( account_name=self.account_name, account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 13d9c34290..c3fe9ec82a 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -1,4 +1,5 @@ """Abstract interface for file storage implementations.""" + from abc import ABC, abstractmethod from collections.abc import Generator @@ -6,8 +7,8 @@ from flask import Flask class BaseStorage(ABC): - """Interface for file storage. - """ + """Interface for file storage.""" + app = None def __init__(self, app: Flask): diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index ef6cd69039..9ed1fcf0b4 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -11,16 +11,16 @@ from extensions.storage.base_storage import BaseStorage class GoogleStorage(BaseStorage): - """Implementation for google storage. - """ + """Implementation for google storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') - service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') + self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") + service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") # if service_account_json_str is empty, use Application Default Credentials if service_account_json_str: - service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) @@ -43,9 +43,10 @@ class GoogleStorage(BaseStorage): def generate(filename: str = filename) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - with closing(blob.open(mode='rb')) as blob_stream: + with closing(blob.open(mode="rb")) as blob_stream: while chunk := blob_stream.read(4096): yield chunk + return generate() def download(self, filename, target_filepath): @@ -60,4 +61,4 @@ class GoogleStorage(BaseStorage): def delete(self, filename): bucket = self.client.get_bucket(self.bucket_name) - bucket.delete_blob(filename) \ No newline at end of file + bucket.delete_blob(filename) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py index 389ef12f82..46ee4bf80f 100644 --- a/api/extensions/storage/local_storage.py +++ b/api/extensions/storage/local_storage.py @@ -8,21 +8,20 @@ from extensions.storage.base_storage import BaseStorage class LocalStorage(BaseStorage): - """Implementation for local storage. - """ + """Implementation for local storage.""" def __init__(self, app: Flask): super().__init__(app) - folder = self.app.config.get('STORAGE_LOCAL_PATH') + folder = self.app.config.get("STORAGE_LOCAL_PATH") if not os.path.isabs(folder): folder = os.path.join(app.root_path, folder) self.folder = folder def save(self, filename, data): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) @@ -31,10 +30,10 @@ class LocalStorage(BaseStorage): f.write(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -46,10 +45,10 @@ class LocalStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -61,10 +60,10 @@ class LocalStorage(BaseStorage): return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -72,17 +71,17 @@ class LocalStorage(BaseStorage): shutil.copyfile(filename, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename return os.path.exists(filename) def delete(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if os.path.exists(filename): os.remove(filename) diff --git a/api/extensions/storage/oci_storage.py b/api/extensions/storage/oci_storage.py index e78d870950..e32fa0a0ae 100644 --- a/api/extensions/storage/oci_storage.py +++ b/api/extensions/storage/oci_storage.py @@ -12,14 +12,14 @@ class OCIStorage(BaseStorage): def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('OCI_BUCKET_NAME') + self.bucket_name = app_config.get("OCI_BUCKET_NAME") self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('OCI_SECRET_KEY'), - aws_access_key_id=app_config.get('OCI_ACCESS_KEY'), - endpoint_url=app_config.get('OCI_ENDPOINT'), - region_name=app_config.get('OCI_REGION') - ) + "s3", + aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), + aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), + endpoint_url=app_config.get("OCI_ENDPOINT"), + region_name=app_config.get("OCI_REGION"), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -27,9 +27,9 @@ class OCIStorage(BaseStorage): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,12 +40,13 @@ class OCIStorage(BaseStorage): try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): @@ -61,4 +62,4 @@ class OCIStorage(BaseStorage): return False def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) \ No newline at end of file + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 787596fa79..022ce5b14a 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -10,24 +10,24 @@ from extensions.storage.base_storage import BaseStorage class S3Storage(BaseStorage): - """Implementation for s3 storage. - """ + """Implementation for s3 storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('S3_BUCKET_NAME') - if app_config.get('S3_USE_AWS_MANAGED_IAM'): + self.bucket_name = app_config.get("S3_BUCKET_NAME") + if app_config.get("S3_USE_AWS_MANAGED_IAM"): session = boto3.Session() - self.client = session.client('s3') + self.client = session.client("s3") else: self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('S3_SECRET_KEY'), - aws_access_key_id=app_config.get('S3_ACCESS_KEY'), - endpoint_url=app_config.get('S3_ENDPOINT'), - region_name=app_config.get('S3_REGION'), - config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')}) - ) + "s3", + aws_secret_access_key=app_config.get("S3_SECRET_KEY"), + aws_access_key_id=app_config.get("S3_ACCESS_KEY"), + endpoint_url=app_config.get("S3_ENDPOINT"), + region_name=app_config.get("S3_REGION"), + config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -35,9 +35,9 @@ class S3Storage(BaseStorage): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -48,12 +48,13 @@ class S3Storage(BaseStorage): try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index e2c1ca55e3..1d499cd3bc 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -7,18 +7,17 @@ from extensions.storage.base_storage import BaseStorage class TencentStorage(BaseStorage): - """Implementation for tencent cos storage. - """ + """Implementation for tencent cos storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME') + self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") config = CosConfig( - Region=app_config.get('TENCENT_COS_REGION'), - SecretId=app_config.get('TENCENT_COS_SECRET_ID'), - SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'), - Scheme=app_config.get('TENCENT_COS_SCHEME'), + Region=app_config.get("TENCENT_COS_REGION"), + SecretId=app_config.get("TENCENT_COS_SECRET_ID"), + SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), + Scheme=app_config.get("TENCENT_COS_SCHEME"), ) self.client = CosS3Client(config) @@ -26,19 +25,19 @@ class TencentStorage(BaseStorage): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].get_stream(chunk_size=4096) + yield from response["Body"].get_stream(chunk_size=4096) return generate() def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - response['Body'].get_stream_to_file(target_filepath) + response["Body"].get_stream_to_file(target_filepath) def exists(self, filename): return self.client.object_exists(Bucket=self.bucket_name, Key=filename) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index c778084475..379dcc6d16 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -5,7 +5,7 @@ from libs.helper import TimestampField annotation_fields = { "id": fields.String, "question": fields.String, - "answer": fields.Raw(attribute='content'), + "answer": fields.Raw(attribute="content"), "hit_count": fields.Integer, "created_at": TimestampField, # 'account': fields.Nested(simple_account_fields, allow_null=True) @@ -21,8 +21,8 @@ annotation_hit_history_fields = { "score": fields.Float, "question": fields.String, "created_at": TimestampField, - "match": fields.String(attribute='annotation_question'), - "response": fields.String(attribute='annotation_content') + "match": fields.String(attribute="annotation_question"), + "response": fields.String(attribute="annotation_content"), } annotation_hit_history_list_fields = { diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 749e9900de..a85d4a34db 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -8,16 +8,16 @@ class HiddenAPIKey(fields.Raw): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: - return api_key[0] + '******' + api_key[-1] + return api_key[0] + "******" + api_key[-1] # If the api_key is greater than 8 characters, show the first three and the last three characters else: - return api_key[:3] + '******' + api_key[-3:] + return api_key[:3] + "******" + api_key[-3:] api_based_extension_fields = { - 'id': fields.String, - 'name': fields.String, - 'api_endpoint': fields.String, - 'api_key': HiddenAPIKey, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "api_endpoint": fields.String, + "api_key": HiddenAPIKey, + "created_at": TimestampField, } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 94d804a919..7036d58e4a 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -3,157 +3,153 @@ from flask_restful import fields from libs.helper import TimestampField app_detail_kernel_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, } related_app_list = { - 'data': fields.List(fields.Nested(app_detail_kernel_fields)), - 'total': fields.Integer, + "data": fields.List(fields.Nested(app_detail_kernel_fields)), + "total": fields.Integer, } model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), - 'text_to_speech': fields.Raw(attribute='text_to_speech_dict'), - 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), - 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), - 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'dataset_query_variable': fields.String, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - 'prompt_type': fields.String, - 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), - 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), - 'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), - 'file_upload': fields.Raw(attribute='file_upload_dict'), - 'created_at': TimestampField + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "speech_to_text": fields.Raw(attribute="speech_to_text_dict"), + "text_to_speech": fields.Raw(attribute="text_to_speech_dict"), + "retriever_resource": fields.Raw(attribute="retriever_resource_dict"), + "annotation_reply": fields.Raw(attribute="annotation_reply_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"), + "external_data_tools": fields.Raw(attribute="external_data_tools_list"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "dataset_query_variable": fields.String, + "pre_prompt": fields.String, + "agent_mode": fields.Raw(attribute="agent_mode_dict"), + "prompt_type": fields.String, + "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"), + "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), + "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), + "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_at": TimestampField, } app_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'tracing': fields.Raw, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "tracing": fields.Raw, + "created_at": TimestampField, } prompt_config_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } model_config_partial_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} app_partial_fields = { - 'id': fields.String, - 'name': fields.String, - 'max_active_requests': fields.Raw(), - 'description': fields.String(attribute='desc_or_prompt'), - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), - 'created_at': TimestampField, - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "max_active_requests": fields.Raw(), + "description": fields.String(attribute="desc_or_prompt"), + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "created_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), } app_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), } template_fields = { - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'mode': fields.String, - 'model_config': fields.Nested(model_config_fields), + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, + "model_config": fields.Nested(model_config_fields), } template_list_fields = { - 'data': fields.List(fields.Nested(template_fields)), + "data": fields.List(fields.Nested(template_fields)), } site_fields = { - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'app_base_url': fields.String, - 'show_workflow_steps': fields.Boolean, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, } app_detail_fields_with_site = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'site': fields.Nested(site_fields), - 'api_base_url': fields.String, - 'created_at': TimestampField, - 'deleted_tools': fields.List(fields.String), + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "created_at": TimestampField, + "deleted_tools": fields.List(fields.String), } app_site_fields = { - 'app_id': fields.String, - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 79ceb02685..1b15fe3880 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -6,205 +6,202 @@ from libs.helper import TimestampField class MessageTextField(fields.Raw): def format(self, value): - return value[0]['text'] if value else '' + return value[0]["text"] if value else "" feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(simple_account_fields, allow_null=True), + "rating": fields.String, + "content": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account": fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { - 'id': fields.String, - 'question': fields.String, - 'content': fields.String, - 'account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "question": fields.String, + "content": fields.String, + "account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } annotation_hit_history_fields = { - 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "annotation_id": fields.String(attribute="id"), + "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } message_file_fields = { - 'id': fields.String, - 'type': fields.String, - 'url': fields.String, - 'belongs_to': fields.String(default='user'), + "id": fields.String, + "type": fields.String, + "url": fields.String, + "belongs_to": fields.String(default="user"), } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String), + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'workflow_run_id': fields.String, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'metadata': fields.Raw(attribute='message_metadata_dict'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": fields.Raw, + "message_tokens": fields.Integer, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "answer_tokens": fields.Integer, + "provider_response_latency": fields.Float, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "feedbacks": fields.List(fields.Nested(feedback_fields)), + "workflow_run_id": fields.String, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "metadata": fields.Raw(attribute="message_metadata_dict"), + "status": fields.String, + "error": fields.String, } -feedback_stat_fields = { - 'like': fields.Integer, - 'dislike': fields.Integer -} +feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'model': fields.Raw, - 'user_input_form': fields.Raw, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw, + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "model": fields.Raw, + "user_input_form": fields.Raw, + "pre_prompt": fields.String, + "agent_mode": fields.Raw, } simple_configs_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } simple_message_detail_fields = { - 'inputs': fields.Raw, - 'query': fields.String, - 'message': MessageTextField, - 'answer': fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": MessageTextField, + "answer": fields.String, } conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String(), - 'from_account_id': fields.String, - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'model_config': fields.Nested(simple_model_config_fields), - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields), - 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String(), + "from_account_id": fields.String, + "read_at": TimestampField, + "created_at": TimestampField, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "model_config": fields.Nested(simple_model_config_fields), + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), } conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_fields), attribute="items"), } conversation_message_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Nested(model_config_fields), - 'message': fields.Nested(message_detail_fields, attribute='first_message'), + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "model_config": fields.Nested(model_config_fields), + "message": fields.Nested(message_detail_fields, attribute="first_message"), } conversation_with_summary_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String, - 'from_account_id': fields.String, - 'name': fields.String, - 'summary': fields.String(attribute='summary_or_query'), - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(simple_model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String, + "from_account_id": fields.String, + "name": fields.String, + "summary": fields.String(attribute="summary_or_query"), + "read_at": TimestampField, + "created_at": TimestampField, + "annotated": fields.Boolean, + "model_config": fields.Nested(simple_model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } conversation_with_summary_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), } conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'introduction': fields.String, - 'model_config': fields.Nested(model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "annotated": fields.Boolean, + "introduction": fields.String, + "model_config": fields.Nested(model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } simple_conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "inputs": fields.Raw, + "status": fields.String, + "introduction": fields.String, + "created_at": TimestampField, } conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(simple_conversation_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(simple_conversation_fields)), } conversation_with_model_config_fields = { **simple_conversation_fields, - 'model_config': fields.Raw, + "model_config": fields.Raw, } conversation_with_model_config_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_with_model_config_fields)), } diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 782a848c1a..983e50e73c 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -3,19 +3,19 @@ from flask_restful import fields from libs.helper import TimestampField conversation_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value_type': fields.String(attribute='value_type.value'), - 'value': fields.String, - 'description': fields.String, - 'created_at': TimestampField, - 'updated_at': TimestampField, + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.String, + "description": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, } paginated_conversation_variable_fields = { - 'page': fields.Integer, - 'limit': fields.Integer, - 'total': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'), + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"), } diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85..071071376f 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -2,64 +2,56 @@ from flask_restful import fields from libs.helper import TimestampField -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'is_bound': fields.Boolean, - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)) + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), } integrate_notion_info_list_fields = { - 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), + "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), } -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)), - 'total': fields.Integer + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), + "total": fields.Integer, } integrate_fields = { - 'id': fields.String, - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'disabled': fields.Boolean, - 'link': fields.String, - 'source_info': fields.Nested(integrate_workspace_fields) + "id": fields.String, + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "disabled": fields.Boolean, + "link": fields.String, + "source_info": fields.Nested(integrate_workspace_fields), } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), -} \ No newline at end of file + "data": fields.List(fields.Nested(integrate_fields)), +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a9f79b5c67..9cf8da7acd 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -3,73 +3,64 @@ from flask_restful import fields from libs.helper import TimestampField dataset_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "created_by": fields.String, + "created_at": TimestampField, } -reranking_model_fields = { - 'reranking_provider_name': fields.String, - 'reranking_model_name': fields.String -} +reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} -keyword_setting_fields = { - 'keyword_weight': fields.Float -} +keyword_setting_fields = {"keyword_weight": fields.Float} vector_setting_fields = { - 'vector_weight': fields.Float, - 'embedding_model_name': fields.String, - 'embedding_provider_name': fields.String, + "vector_weight": fields.Float, + "embedding_model_name": fields.String, + "embedding_provider_name": fields.String, } weighted_score_fields = { - 'keyword_setting': fields.Nested(keyword_setting_fields), - 'vector_setting': fields.Nested(vector_setting_fields), + "keyword_setting": fields.Nested(keyword_setting_fields), + "vector_setting": fields.Nested(vector_setting_fields), } dataset_retrieval_model_fields = { - 'search_method': fields.String, - 'reranking_enable': fields.Boolean, - 'reranking_mode': fields.String, - 'reranking_model': fields.Nested(reranking_model_fields), - 'weights': fields.Nested(weighted_score_fields, allow_null=True), - 'top_k': fields.Integer, - 'score_threshold_enabled': fields.Boolean, - 'score_threshold': fields.Float + "search_method": fields.String, + "reranking_enable": fields.Boolean, + "reranking_mode": fields.String, + "reranking_model": fields.Nested(reranking_model_fields), + "weights": fields.Nested(weighted_score_fields, allow_null=True), + "top_k": fields.Integer, + "score_threshold_enabled": fields.Boolean, + "score_threshold": fields.Float, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} dataset_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'provider': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'app_count': fields.Integer, - 'document_count': fields.Integer, - 'word_count': fields.Integer, - 'created_by': fields.String, - 'created_at': TimestampField, - 'updated_by': fields.String, - 'updated_at': TimestampField, - 'embedding_model': fields.String, - 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean, - 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields), - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "description": fields.String, + "provider": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "app_count": fields.Integer, + "document_count": fields.Integer, + "word_count": fields.Integer, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "embedding_model": fields.String, + "embedding_model_provider": fields.String, + "embedding_available": fields.Boolean, + "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "tags": fields.List(fields.Nested(tag_fields)), } dataset_query_detail_fields = { @@ -79,7 +70,5 @@ dataset_query_detail_fields = { "source_app_id": fields.String, "created_by_role": fields.String, "created_by": fields.String, - "created_at": TimestampField + "created_at": TimestampField, } - - diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index e8215255b3..a83ec7bc97 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -4,75 +4,73 @@ from fields.dataset_fields import dataset_fields from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'doc_form': fields.String, + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "doc_form": fields.String, } document_with_segments_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } dataset_and_document_fields = { - 'dataset': fields.Nested(dataset_fields), - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String + "dataset": fields.Nested(dataset_fields), + "documents": fields.List(fields.Nested(document_fields)), + "batch": fields.String, } document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, + "id": fields.String, + "indexing_status": fields.String, + "processing_started_at": TimestampField, + "parsing_completed_at": TimestampField, + "cleaning_completed_at": TimestampField, + "splitting_completed_at": TimestampField, + "completed_at": TimestampField, + "paused_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } -document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) -} \ No newline at end of file +document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))} diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ee630c12c2..99e529f9d1 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,8 +1,8 @@ from flask_restful import fields simple_end_user_fields = { - 'id': fields.String, - 'type': fields.String, - 'is_anonymous': fields.Boolean, - 'session_id': fields.String, + "id": fields.String, + "type": fields.String, + "is_anonymous": fields.Boolean, + "session_id": fields.String, } diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc..e5a03ce77e 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -3,17 +3,17 @@ from flask_restful import fields from libs.helper import TimestampField upload_config_fields = { - 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer, - 'image_file_size_limit': fields.Integer, + "file_size_limit": fields.Integer, + "batch_count_limit": fields.Integer, + "image_file_size_limit": fields.Integer, } file_fields = { - 'id': fields.String, - 'name': fields.String, - 'size': fields.Integer, - 'extension': fields.String, - 'mime_type': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, -} \ No newline at end of file + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378..f36e80f8d4 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -3,39 +3,39 @@ from flask_restful import fields from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'data_source_type': fields.String, - 'name': fields.String, - 'doc_type': fields.String, + "id": fields.String, + "data_source_type": fields.String, + "name": fields.String, + "doc_type": fields.String, } segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'document': fields.Nested(document_fields), + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "document": fields.Nested(document_fields), } hit_testing_record_fields = { - 'segment': fields.Nested(segment_fields), - 'score': fields.Float, - 'tsne_position': fields.Raw -} \ No newline at end of file + "segment": fields.Nested(segment_fields), + "score": fields.Float, + "tsne_position": fields.Raw, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 35cc5a6475..b87cc65324 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -3,23 +3,21 @@ from flask_restful import fields from libs.helper import TimestampField app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': TimestampField, - 'editable': fields.Boolean, - 'uninstallable': fields.Boolean + "id": fields.String, + "app": fields.Nested(app_fields), + "app_owner_tenant_id": fields.String, + "is_pinned": fields.Boolean, + "last_used_at": TimestampField, + "editable": fields.Boolean, + "uninstallable": fields.Boolean, } -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} \ No newline at end of file +installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index d061b59c34..1cf8e408d1 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -2,38 +2,32 @@ from flask_restful import fields from libs.helper import TimestampField -simple_account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} +simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "is_password_set": fields.Boolean, + "interface_language": fields.String, + "interface_theme": fields.String, + "timezone": fields.String, + "last_login_at": TimestampField, + "last_login_ip": fields.String, + "created_at": TimestampField, } account_with_role_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'last_active_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "last_login_at": TimestampField, + "last_active_at": TimestampField, + "created_at": TimestampField, + "role": fields.String, + "status": fields.String, } -account_with_role_list_fields = { - 'accounts': fields.List(fields.Nested(account_with_role_fields)) -} +account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 3116843589..3d2df87afb 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,83 +3,79 @@ from flask_restful import fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String) + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index e41d1a53dd..2dd4cb45be 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -3,31 +3,31 @@ from flask_restful import fields from libs.helper import TimestampField segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, } segment_list_response = { - 'data': fields.List(fields.Nested(segment_fields)), - 'has_more': fields.Boolean, - 'limit': fields.Integer + "data": fields.List(fields.Nested(segment_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, } diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index f7e030b738..9af4fc57dd 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,8 +1,3 @@ from flask_restful import fields -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String, - 'binding_count': fields.String -} \ No newline at end of file +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index e230c159fb..a53b546249 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -7,18 +7,18 @@ from libs.helper import TimestampField workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), "created_from": fields.String, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "created_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, } workflow_app_log_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"), } diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index c1dd0e184a..240b8f2eb0 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,43 +13,43 @@ class EnvironmentVariableField(fields.Raw): # Mask secret variables values in environment_variables if isinstance(value, SecretVariable): return { - 'id': value.id, - 'name': value.name, - 'value': encrypter.obfuscated_token(value.value), - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": encrypter.obfuscated_token(value.value), + "value_type": value.value_type.value, } if isinstance(value, Variable): return { - 'id': value.id, - 'name': value.name, - 'value': value.value, - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": value.value, + "value_type": value.value_type.value, } if isinstance(value, dict): - value_type = value.get('value_type') + value_type = value.get("value_type") if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: - raise ValueError(f'Unsupported environment variable value type: {value_type}') + raise ValueError(f"Unsupported environment variable value type: {value_type}") return value conversation_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value_type': fields.String(attribute='value_type.value'), - 'value': fields.Raw, - 'description': fields.String, + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, + "description": fields.String, } workflow_fields = { - 'id': fields.String, - 'graph': fields.Raw(attribute='graph_dict'), - 'features': fields.Raw(attribute='features_dict'), - 'hash': fields.String(attribute='unique_hash'), - 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), - 'created_at': TimestampField, - 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField, - 'tool_published': fields.Boolean, - 'environment_variables': fields.List(EnvironmentVariableField()), - 'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)), + "id": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "features": fields.Raw(attribute="features_dict"), + "hash": fields.String(attribute="unique_hash"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, + "tool_published": fields.Boolean, + "environment_variables": fields.List(EnvironmentVariableField()), + "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), } diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3e798473cd..1413adf719 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -13,7 +13,7 @@ workflow_run_for_log_fields = { "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_for_list_fields = { @@ -24,9 +24,9 @@ workflow_run_for_list_fields = { "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_for_list_fields = { @@ -39,40 +39,40 @@ advanced_chat_workflow_run_for_list_fields = { "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), } workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } workflow_run_detail_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), + "graph": fields.Raw(attribute="graph_dict"), + "inputs": fields.Raw(attribute="inputs_dict"), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), + "outputs": fields.Raw(attribute="outputs_dict"), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_node_execution_fields = { @@ -82,21 +82,21 @@ workflow_run_node_execution_fields = { "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.Raw(attribute='inputs_dict'), - "process_data": fields.Raw(attribute='process_data_dict'), - "outputs": fields.Raw(attribute='outputs_dict'), + "inputs": fields.Raw(attribute="inputs_dict"), + "process_data": fields.Raw(attribute="process_data_dict"), + "outputs": fields.Raw(attribute="outputs_dict"), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), "extras": fields.Raw, "created_at": TimestampField, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "finished_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "finished_at": TimestampField, } workflow_run_node_execution_list_fields = { - 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), + "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } diff --git a/api/pyproject.toml b/api/pyproject.toml index 82e2aaeb2b..3e107f5e9b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -69,7 +69,18 @@ ignore = [ ] [tool.ruff.format] -quote-style = "single" +exclude = [ + "core/**/*.py", + "controllers/**/*.py", + "models/**/*.py", + "utils/**/*.py", + "migrations/**/*", + "services/**/*.py", + "tasks/**/*.py", + "tests/**/*.py", + "libs/**/*.py", + "configs/**/*.py", +] [tool.pytest_env] OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index ccc1062266..67d0706828 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -11,27 +11,32 @@ from extensions.ext_database import db from models.dataset import Embedding -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_embedding_cache_task(): - click.echo(click.style('Start clean embedding cache.', fg='green')) + click.echo(click.style("Start clean embedding cache.", fg="green")) clean_days = int(dify_config.CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ - .order_by(Embedding.created_at.desc()).limit(100).all() + embedding_ids = ( + db.session.query(Embedding.id) + .filter(Embedding.created_at < thirty_days_ago) + .order_by(Embedding.created_at.desc()) + .limit(100) + .all() + ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] except NotFound: break if embedding_ids: for embedding_id in embedding_ids: - db.session.execute(text( - "DELETE FROM embeddings WHERE id = :embedding_id" - ), {'embedding_id': embedding_id}) + db.session.execute( + text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id} + ) db.session.commit() else: break end_at = time.perf_counter() - click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index b2b2f82b78..3d799bfd4e 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -12,9 +12,9 @@ from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, Document -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_unused_datasets_task(): - click.echo(click.style('Start clean unused datasets indexes.', fg='green')) + click.echo(click.style("Start clean unused datasets indexes.", fg="green")) clean_days = dify_config.CLEAN_DAY_SETTING start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) @@ -22,40 +22,44 @@ def clean_unused_datasets_task(): while True: try: # Subquery for counting new documents - document_subquery_new = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at > thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Subquery for counting old documents - document_subquery_old = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at < thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Main query with join and filter - datasets = (db.session.query(Dataset) - .outerjoin( - document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id - ).outerjoin( - document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id - ).filter( - Dataset.created_at < thirty_days_ago, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0 - ).order_by( - Dataset.created_at.desc() - ).paginate(page=page, per_page=50)) + datasets = ( + db.session.query(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < thirty_days_ago, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break @@ -63,10 +67,11 @@ def clean_unused_datasets_task(): break page += 1 for dataset in datasets: - dataset_query = db.session.query(DatasetQuery).filter( - DatasetQuery.created_at > thirty_days_ago, - DatasetQuery.dataset_id == dataset.id - ).all() + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id) + .all() + ) if not dataset_query or len(dataset_query) == 0: try: # remove index @@ -74,17 +79,14 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = { - Document.enabled: False - } + update_params = {Document.enabled: False} Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() - click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), - fg='green')) + click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: click.echo( - click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) end_at = time.perf_counter() - click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/dev/reformat b/dev/reformat index f50ccb04c4..ad83e897d9 100755 --- a/dev/reformat +++ b/dev/reformat @@ -11,5 +11,8 @@ fi # run ruff linter ruff check --fix ./api +# run ruff formatter +ruff format ./api + # run dotenv-linter linter dotenv-linter ./api/.env.example ./web/.env.example