chore(api): Introduce Ruff Formatter. (#7291)

This commit is contained in:
-LAN- 2024-08-15 12:54:05 +08:00 committed by GitHub
parent 8f16165f92
commit 3571292fbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 1315 additions and 1335 deletions

View File

@ -45,6 +45,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints - name: Lint hints
if: failure() if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors." run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

View File

@ -1,6 +1,6 @@
import os import os
if os.environ.get("DEBUG", "false").lower() != 'true': if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
if os.name == "nt": if os.name == "nt":
os.system('tzutil /s "UTC"') os.system('tzutil /s "UTC"')
else: else:
os.environ['TZ'] = 'UTC' os.environ["TZ"] = "UTC"
time.tzset() time.tzset()
@ -70,13 +70,14 @@ class DifyApp(Flask):
# ------------- # -------------
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# ---------------------------- # ----------------------------
# Application Factory Function # Application Factory Function
# ---------------------------- # ----------------------------
def create_flask_app_with_configs() -> Flask: def create_flask_app_with_configs() -> Flask:
""" """
create a raw flask app create a raw flask app
@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
elif isinstance(value, int | float | bool): elif isinstance(value, int | float | bool):
os.environ[key] = str(value) os.environ[key] = str(value)
elif value is None: elif value is None:
os.environ[key] = '' os.environ[key] = ""
return dify_app return dify_app
@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask: def create_app() -> Flask:
app = create_flask_app_with_configs() app = create_flask_app_with_configs()
app.secret_key = app.config['SECRET_KEY'] app.secret_key = app.config["SECRET_KEY"]
log_handlers = None log_handlers = None
log_file = app.config.get('LOG_FILE') log_file = app.config.get("LOG_FILE")
if log_file: if log_file:
log_dir = os.path.dirname(log_file) log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
@ -111,23 +112,24 @@ def create_app() -> Flask:
RotatingFileHandler( RotatingFileHandler(
filename=log_file, filename=log_file,
maxBytes=1024 * 1024 * 1024, maxBytes=1024 * 1024 * 1024,
backupCount=5 backupCount=5,
), ),
logging.StreamHandler(sys.stdout) logging.StreamHandler(sys.stdout),
] ]
logging.basicConfig( logging.basicConfig(
level=app.config.get('LOG_LEVEL'), level=app.config.get("LOG_LEVEL"),
format=app.config.get('LOG_FORMAT'), format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get('LOG_DATEFORMAT'), datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers, handlers=log_handlers,
force=True force=True,
) )
log_tz = app.config.get('LOG_TZ') log_tz = app.config.get("LOG_TZ")
if log_tz: if log_tz:
from datetime import datetime from datetime import datetime
import pytz import pytz
timezone = pytz.timezone(log_tz) timezone = pytz.timezone(log_tz)
def time_converter(seconds): def time_converter(seconds):
@ -162,24 +164,24 @@ def initialize_extensions(app):
@login_manager.request_loader @login_manager.request_loader
def load_user_from_request(request_from_flask_login): def load_user_from_request(request_from_flask_login):
"""Load user based on the request.""" """Load user based on the request."""
if request.blueprint not in ['console', 'inner_api']: if request.blueprint not in ["console", "inner_api"]:
return None return None
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '') auth_header = request.headers.get("Authorization", "")
if not auth_header: if not auth_header:
auth_token = request.args.get('_token') auth_token = request.args.get("_token")
if not auth_token: if not auth_token:
raise Unauthorized('Invalid Authorization token.') raise Unauthorized("Invalid Authorization token.")
else: else:
if ' ' not in auth_header: if " " not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower() auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer': if auth_scheme != "bearer":
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token) decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id') user_id = decoded.get("user_id")
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
if account: if account:
@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
@login_manager.unauthorized_handler @login_manager.unauthorized_handler
def unauthorized_handler(): def unauthorized_handler():
"""Handle unauthorized requests.""" """Handle unauthorized requests."""
return Response(json.dumps({ return Response(
'code': 'unauthorized', json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
'message': "Unauthorized." status=401,
}), status=401, content_type="application/json") content_type="application/json",
)
# register blueprint routers # register blueprint routers
@ -204,38 +207,36 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp from controllers.web import bp as web_bp
CORS(service_api_bp, CORS(
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], service_api_bp,
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
) )
app.register_blueprint(service_api_bp) app.register_blueprint(service_api_bp)
CORS(web_bp, CORS(
resources={ web_bp,
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True, supports_credentials=True,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=['X-Version', 'X-Env'] expose_headers=["X-Version", "X-Env"],
) )
app.register_blueprint(web_bp) app.register_blueprint(web_bp)
CORS(console_app_bp, CORS(
resources={ console_app_bp,
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True, supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'], allow_headers=["Content-Type", "Authorization"],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=['X-Version', 'X-Env'] expose_headers=["X-Version", "X-Env"],
) )
app.register_blueprint(console_app_bp) app.register_blueprint(console_app_bp)
CORS(files_bp, CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(files_bp) app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp) app.register_blueprint(inner_api_bp)
@ -245,29 +246,29 @@ def register_blueprints(app):
app = create_app() app = create_app()
celery = app.extensions["celery"] celery = app.extensions["celery"]
if app.config.get('TESTING'): if app.config.get("TESTING"):
print("App is running in TESTING mode") print("App is running in TESTING mode")
@app.after_request @app.after_request
def after_request(response): def after_request(response):
"""Add Version headers to the response.""" """Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0) response.set_cookie("remember_token", "", expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION']) response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add('X-Env', app.config['DEPLOY_ENV']) response.headers.add("X-Env", app.config["DEPLOY_ENV"])
return response return response
@app.route('/health') @app.route("/health")
def health(): def health():
return Response(json.dumps({ return Response(
'pid': os.getpid(), json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
'status': 'ok', status=200,
'version': app.config['CURRENT_VERSION'] content_type="application/json",
}), status=200, content_type="application/json") )
@app.route('/threads') @app.route("/threads")
def threads(): def threads():
num_threads = threading.active_count() num_threads = threading.active_count()
threads = threading.enumerate() threads = threading.enumerate()
@ -278,32 +279,34 @@ def threads():
thread_id = thread.ident thread_id = thread.ident
is_alive = thread.is_alive() is_alive = thread.is_alive()
thread_list.append({ thread_list.append(
'name': thread_name, {
'id': thread_id, "name": thread_name,
'is_alive': is_alive "id": thread_id,
}) "is_alive": is_alive,
}
)
return { return {
'pid': os.getpid(), "pid": os.getpid(),
'thread_num': num_threads, "thread_num": num_threads,
'threads': thread_list "threads": thread_list,
} }
@app.route('/db-pool-stat') @app.route("/db-pool-stat")
def pool_stat(): def pool_stat():
engine = db.engine engine = db.engine
return { return {
'pid': os.getpid(), "pid": os.getpid(),
'pool_size': engine.pool.size(), "pool_size": engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(), "checked_in_connections": engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(), "checked_out_connections": engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(), "overflow_connections": engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(), "connection_timeout": engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle "recycle_time": db.engine.pool._recycle,
} }
if __name__ == '__main__': if __name__ == "__main__":
app.run(host='0.0.0.0', port=5001) app.run(host="0.0.0.0", port=5001)

View File

@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.') @click.command("reset-password", help="Reset the account password.")
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') @click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option('--new-password', prompt=True, help='the new password.') @click.option("--new-password", prompt=True, help="the new password.")
@click.option('--password-confirm', prompt=True, help='the new password confirm.') @click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm): def reset_password(email, new_password, password_confirm):
""" """
Reset password of owner account Reset password of owner account
Only available in SELF_HOSTED mode Only available in SELF_HOSTED mode
""" """
if str(new_password).strip() != str(password_confirm).strip(): if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style('sorry. The two passwords do not match.', fg='red')) click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return return
account = db.session.query(Account). \ account = db.session.query(Account).filter(Account.email == email).one_or_none()
filter(Account.email == email). \
one_or_none()
if not account: if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return return
try: try:
valid_password(new_password) valid_password(new_password)
except: except:
click.echo( click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
return return
# generate password salt # generate password salt
@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
db.session.commit() db.session.commit()
click.echo(click.style('Congratulations! Password has been reset.', fg='green')) click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command('reset-email', help='Reset the account email.') @click.command("reset-email", help="Reset the account email.")
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') @click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option('--new-email', prompt=True, help='the new email.') @click.option("--new-email", prompt=True, help="the new email.")
@click.option('--email-confirm', prompt=True, help='the new email confirm.') @click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm): def reset_email(email, new_email, email_confirm):
""" """
Replace account email Replace account email
:return: :return:
""" """
if str(new_email).strip() != str(email_confirm).strip(): if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return return
account = db.session.query(Account). \ account = db.session.query(Account).filter(Account.email == email).one_or_none()
filter(Account.email == email). \
one_or_none()
if not account: if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return return
try: try:
email_validate(new_email) email_validate(new_email)
except: except:
click.echo( click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
return return
account.email = new_email account.email = new_email
db.session.commit() db.session.commit()
click.echo(click.style('Congratulations!, email has been reset.', fg='green')) click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. ' @click.command(
'After the reset, all LLM credentials will become invalid, ' "reset-encrypt-key-pair",
'requiring re-entry.' help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
'Only support SELF_HOSTED mode.') "After the reset, all LLM credentials will become invalid, "
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' "requiring re-entry."
' this operation cannot be rolled back!', fg='red')) "Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair(): def reset_encrypt_key_pair():
""" """
Reset the encrypted key pair of workspace for encrypt LLM credentials. Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry. After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode. Only support SELF_HOSTED mode.
""" """
if dify_config.EDITION != 'SELF_HOSTED': if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return return
tenants = db.session.query(Tenant).all() tenants = db.session.query(Tenant).all()
for tenant in tenants: for tenant in tenants:
if not tenant: if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return return
tenant.encrypt_public_key = generate_key_pair(tenant.id) tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit() db.session.commit()
click.echo(click.style('Congratulations! ' click.echo(
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@click.command('vdb-migrate', help='migrate vector db.') @click.command("vdb-migrate", help="migrate vector db.")
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str): def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']: if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database() migrate_knowledge_vector_database()
if scope in ['annotation', 'all']: if scope in ["annotation", "all"]:
migrate_annotation_vector_database() migrate_annotation_vector_database()
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
""" """
Migrate annotation datas to target vector database . Migrate annotation datas to target vector database .
""" """
click.echo(click.style('Start migrate annotation data.', fg='green')) click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0 create_count = 0
skipped_count = 0 skipped_count = 0
total_count = 0 total_count = 0
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
while True: while True:
try: try:
# get apps info # get apps info
apps = db.session.query(App).filter( apps = (
App.status == 'normal' db.session.query(App)
).order_by(App.created_at.desc()).paginate(page=page, per_page=50) .filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound: except NotFound:
break break
page += 1 page += 1
for app in apps: for app in apps:
total_count = total_count + 1 total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. ' click.echo(
+ f'{create_count} created, {skipped_count} skipped.') f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try: try:
click.echo('Create app annotation index: {}'.format(app.id)) click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter( app_annotation_setting = (
AppAnnotationSetting.app_id == app.id db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
).first() )
if not app_annotation_setting: if not app_annotation_setting:
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id)) click.echo("App annotation setting is disabled: {}".format(app.id))
continue continue
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( dataset_collection_binding = (
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id db.session.query(DatasetCollectionBinding)
).first() .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id)) click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
indexing_technique='high_quality', indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name, embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name, embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id collection_binding_id=dataset_collection_binding.id,
) )
documents = [] documents = []
if annotations: if annotations:
for annotation in annotations: for annotation in annotations:
document = Document( document = Document(
page_content=annotation.question, page_content=annotation.question,
metadata={ metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
) )
documents.append(document) documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Start to migrate annotation, app_id: {app.id}.") click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try: try:
vector.delete() vector.delete()
click.echo( click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
except Exception as e: except Exception as e:
click.echo( click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
raise e raise e
if documents: if documents:
try: try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo( click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green')) click.style(
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e: except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e raise e
click.echo(f'Successfully migrated app annotation {app.id}.') click.echo(f"Successfully migrated app annotation {app.id}.")
create_count += 1 create_count += 1
except Exception as e: except Exception as e:
click.echo( click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), click.style(
fg='red')) "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue continue
click.echo( click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', click.style(
fg='green')) f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
def migrate_knowledge_vector_database(): def migrate_knowledge_vector_database():
""" """
Migrate vector database datas to target vector database . Migrate vector database datas to target vector database .
""" """
click.echo(click.style('Start migrate vector db.', fg='green')) click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0 create_count = 0
skipped_count = 0 skipped_count = 0
total_count = 0 total_count = 0
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
page = 1 page = 1
while True: while True:
try: try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ datasets = (
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound: except NotFound:
break break
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
total_count = total_count + 1 total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. ' click.echo(
+ f'{create_count} created, {skipped_count} skipped.') f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try: try:
click.echo('Create dataset vdb index: {}'.format(dataset.id)) click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict: if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type: if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
continue continue
collection_name = '' collection_name = ""
if vector_type == VectorType.WEAVIATE: if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT: elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ dataset_collection_binding = (
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ db.session.query(DatasetCollectionBinding)
one_or_none() .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding: if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name collection_name = dataset_collection_binding.collection_name
else: else:
raise ValueError('Dataset Collection Bindings is not exist!') raise ValueError("Dataset Collection Bindings is not exist!")
else: else:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS: elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT: elif vector_type == VectorType.RELYT:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT: elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR: elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH: elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": VectorType.OPENSEARCH, "type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name},
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB: elif vector_type == VectorType.ANALYTICDB:
@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": VectorType.ANALYTICDB, "type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name},
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH: elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id) index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
"type": 'elasticsearch',
"vector_store": {"class_prefix": index_name}
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
else: else:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")
@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
try: try:
vector.delete() vector.delete()
click.echo( click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', click.style(
fg='green')) f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e: except Exception as e:
click.echo( click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', click.style(
fg='red')) f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
raise e raise e
dataset_documents = db.session.query(DatasetDocument).filter( dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset.id, DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed', DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
).all() )
.all()
)
documents = [] documents = []
segments_count = 0 segments_count = 0
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter( segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed', DocumentSegment.status == "completed",
DocumentSegment.enabled == True DocumentSegment.enabled == True,
).all() )
.all()
)
for segment in segments: for segment in segments:
document = Document( document = Document(
@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
"doc_hash": segment.index_node_hash, "doc_hash": segment.index_node_hash,
"document_id": segment.document_id, "document_id": segment.document_id,
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
} },
) )
documents.append(document) documents.append(document)
@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():
if documents: if documents:
try: try:
click.echo(click.style( click.echo(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', click.style(
fg='green')) f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents) vector.create(documents)
click.echo( click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e: except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e raise e
db.session.add(dataset) db.session.add(dataset)
db.session.commit() db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.') click.echo(f"Successfully migrated dataset {dataset.id}.")
create_count += 1 create_count += 1
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
click.echo( click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
fg='red')) )
continue continue
click.echo( click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', click.style(
fg='green')) f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') @click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
def convert_to_agent_apps(): def convert_to_agent_apps():
""" """
Convert Agent Assistant to Agent App. Convert Agent Assistant to Agent App.
""" """
click.echo(click.style('Start convert to agent apps.', fg='green')) click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = [] proceeded_app_ids = []
@ -466,7 +480,7 @@ def convert_to_agent_apps():
break break
for app in apps: for app in apps:
click.echo('Converting app: {}'.format(app.id)) click.echo("Converting app: {}".format(app.id))
try: try:
app.mode = AppMode.AGENT_CHAT.value app.mode = AppMode.AGENT_CHAT.value
@ -478,137 +492,139 @@ def convert_to_agent_apps():
) )
db.session.commit() db.session.commit()
click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
except Exception as e: except Exception as e:
click.echo( click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') @click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') @click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str): def add_qdrant_doc_id_index(field: str):
click.echo(click.style('Start add qdrant doc_id index.', fg='green')) click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant": if vector_type != "qdrant":
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return return
create_count = 0 create_count = 0
try: try:
bindings = db.session.query(DatasetCollectionBinding).all() bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings: if not bindings:
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return return
import qdrant_client import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings: for binding in bindings:
if dify_config.QDRANT_URL is None: if dify_config.QDRANT_URL is None:
raise ValueError('Qdrant url is required.') raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig( qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL, endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY, api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.root_path, root_path=current_app.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT, timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT, grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
) )
try: try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index # create payload index
client.create_payload_index(binding.collection_name, field, client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
field_schema=PayloadSchemaType.KEYWORD)
create_count += 1 create_count += 1
except UnexpectedResponse as e: except UnexpectedResponse as e:
# Collection does not exist, so return # Collection does not exist, so return
if e.status_code == 404: if e.status_code == 404:
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue continue
# Some other error occurred, so re-raise the exception # Some other error occurred, so re-raise the exception
else: else:
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) click.echo(
click.style(
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e: except Exception as e:
click.echo(click.style('Failed to create qdrant client.', fg='red')) click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo( click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
@click.command('create-tenant', help='Create account and tenant.') @click.command("create-tenant", help="Create account and tenant.")
@click.option('--email', prompt=True, help='The email address of the tenant account.') @click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option('--language', prompt=True, help='Account language, default: en-US.') @click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None): def create_tenant(email: str, language: Optional[str] = None):
""" """
Create tenant account Create tenant account
""" """
if not email: if not email:
click.echo(click.style('Sorry, email is required.', fg='red')) click.echo(click.style("Sorry, email is required.", fg="red"))
return return
# Create account # Create account
email = email.strip() email = email.strip()
if '@' not in email: if "@" not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red')) click.echo(click.style("Sorry, invalid email address.", fg="red"))
return return
account_name = email.split('@')[0] account_name = email.split("@")[0]
if language not in languages: if language not in languages:
language = 'en-US' language = "en-US"
# generate random password # generate random password
new_password = secrets.token_urlsafe(16) new_password = secrets.token_urlsafe(16)
# register account # register account
account = RegisterService.register( account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
email=email,
name=account_name,
password=new_password,
language=language
)
TenantService.create_owner_tenant_if_not_exist(account) TenantService.create_owner_tenant_if_not_exist(account)
click.echo(click.style('Congratulations! Account and tenant created.\n' click.echo(
'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command('upgrade-db', help='upgrade the database') @click.command("upgrade-db", help="upgrade the database")
def upgrade_db(): def upgrade_db():
click.echo('Preparing database migration...') click.echo("Preparing database migration...")
lock = redis_client.lock(name='db_upgrade_lock', timeout=60) lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False): if lock.acquire(blocking=False):
try: try:
click.echo(click.style('Start database migration.', fg='green')) click.echo(click.style("Start database migration.", fg="green"))
# run db migration # run db migration
import flask_migrate import flask_migrate
flask_migrate.upgrade() flask_migrate.upgrade()
click.echo(click.style('Database migration successful!', fg='green')) click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e: except Exception as e:
logging.exception(f'Database migration failed, error: {e}') logging.exception(f"Database migration failed, error: {e}")
finally: finally:
lock.release() lock.release()
else: else:
click.echo('Database migration skipped') click.echo("Database migration skipped")
@click.command('fix-app-site-missing', help='Fix app related site missing issue.') @click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing(): def fix_app_site_missing():
""" """
Fix app related site missing issue. Fix app related site missing issue.
""" """
click.echo(click.style('Start fix app related site missing issue.', fg='green')) click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = [] failed_app_ids = []
while True: while True:
@ -639,15 +655,14 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account) app_was_created.send(app, account=account)
except Exception as e: except Exception as e:
failed_app_ids.append(app_id) failed_app_ids.append(app_id)
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red')) click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f'Fix app related site missing issue failed, error: {e}') logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue continue
if not processed_count: if not processed_count:
break break
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
def register_commands(app): def register_commands(app):

View File

@ -1 +1 @@
HIDDEN_VALUE = '[__HIDDEN__]' HIDDEN_VALUE = "[__HIDDEN__]"

View File

@ -1,22 +1,22 @@
language_timezone_mapping = { language_timezone_mapping = {
'en-US': 'America/New_York', "en-US": "America/New_York",
'zh-Hans': 'Asia/Shanghai', "zh-Hans": "Asia/Shanghai",
'zh-Hant': 'Asia/Taipei', "zh-Hant": "Asia/Taipei",
'pt-BR': 'America/Sao_Paulo', "pt-BR": "America/Sao_Paulo",
'es-ES': 'Europe/Madrid', "es-ES": "Europe/Madrid",
'fr-FR': 'Europe/Paris', "fr-FR": "Europe/Paris",
'de-DE': 'Europe/Berlin', "de-DE": "Europe/Berlin",
'ja-JP': 'Asia/Tokyo', "ja-JP": "Asia/Tokyo",
'ko-KR': 'Asia/Seoul', "ko-KR": "Asia/Seoul",
'ru-RU': 'Europe/Moscow', "ru-RU": "Europe/Moscow",
'it-IT': 'Europe/Rome', "it-IT": "Europe/Rome",
'uk-UA': 'Europe/Kyiv', "uk-UA": "Europe/Kyiv",
'vi-VN': 'Asia/Ho_Chi_Minh', "vi-VN": "Asia/Ho_Chi_Minh",
'ro-RO': 'Europe/Bucharest', "ro-RO": "Europe/Bucharest",
'pl-PL': 'Europe/Warsaw', "pl-PL": "Europe/Warsaw",
'hi-IN': 'Asia/Kolkata', "hi-IN": "Asia/Kolkata",
'tr-TR': 'Europe/Istanbul', "tr-TR": "Europe/Istanbul",
'fa-IR': 'Asia/Tehran', "fa-IR": "Asia/Tehran",
} }
languages = list(language_timezone_mapping.keys()) languages = list(language_timezone_mapping.keys())
@ -26,6 +26,5 @@ def supported_language(lang):
if lang in languages: if lang in languages:
return lang return lang
error = ('{lang} is not a valid language.' error = "{lang} is not a valid language.".format(lang=lang)
.format(lang=lang))
raise ValueError(error) raise ValueError(error)

View File

@ -5,82 +5,79 @@ from models.model import AppMode
default_app_templates = { default_app_templates = {
# workflow default mode # workflow default mode
AppMode.WORKFLOW: { AppMode.WORKFLOW: {
'app': { "app": {
'mode': AppMode.WORKFLOW.value, "mode": AppMode.WORKFLOW.value,
'enable_site': True, "enable_site": True,
'enable_api': True "enable_api": True,
} }
}, },
# completion default mode # completion default mode
AppMode.COMPLETION: { AppMode.COMPLETION: {
'app': { "app": {
'mode': AppMode.COMPLETION.value, "mode": AppMode.COMPLETION.value,
'enable_site': True, "enable_site": True,
'enable_api': True "enable_api": True,
}, },
'model_config': { "model_config": {
'model': { "model": {
"provider": "openai", "provider": "openai",
"name": "gpt-4o", "name": "gpt-4o",
"mode": "chat", "mode": "chat",
"completion_params": {} "completion_params": {},
}, },
'user_input_form': json.dumps([ "user_input_form": json.dumps(
[
{ {
"paragraph": { "paragraph": {
"label": "Query", "label": "Query",
"variable": "query", "variable": "query",
"required": True, "required": True,
"default": "" "default": "",
} },
} },
]), ]
'pre_prompt': '{{query}}' ),
"pre_prompt": "{{query}}",
}, },
}, },
# chat default mode # chat default mode
AppMode.CHAT: { AppMode.CHAT: {
'app': { "app": {
'mode': AppMode.CHAT.value, "mode": AppMode.CHAT.value,
'enable_site': True, "enable_site": True,
'enable_api': True "enable_api": True,
}, },
'model_config': { "model_config": {
'model': { "model": {
"provider": "openai", "provider": "openai",
"name": "gpt-4o", "name": "gpt-4o",
"mode": "chat", "mode": "chat",
"completion_params": {} "completion_params": {},
} },
} },
}, },
# advanced-chat default mode # advanced-chat default mode
AppMode.ADVANCED_CHAT: { AppMode.ADVANCED_CHAT: {
'app': { "app": {
'mode': AppMode.ADVANCED_CHAT.value, "mode": AppMode.ADVANCED_CHAT.value,
'enable_site': True, "enable_site": True,
'enable_api': True "enable_api": True,
} },
}, },
# agent-chat default mode # agent-chat default mode
AppMode.AGENT_CHAT: { AppMode.AGENT_CHAT: {
'app': { "app": {
'mode': AppMode.AGENT_CHAT.value, "mode": AppMode.AGENT_CHAT.value,
'enable_site': True, "enable_site": True,
'enable_api': True "enable_api": True,
}, },
'model_config': { "model_config": {
'model': { "model": {
"provider": "openai", "provider": "openai",
"name": "gpt-4o", "name": "gpt-4o",
"mode": "chat", "mode": "chat",
"completion_params": {} "completion_params": {},
} },
} },
} },
} }

View File

@ -2,6 +2,6 @@ from contextvars import ContextVar
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar('tenant_id') tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")

View File

@ -1,13 +1,13 @@
from blinker import signal from blinker import signal
# sender: app # sender: app
app_was_created = signal('app-was-created') app_was_created = signal("app-was-created")
# sender: app, kwargs: app_model_config # sender: app, kwargs: app_model_config
app_model_config_was_updated = signal('app-model-config-was-updated') app_model_config_was_updated = signal("app-model-config-was-updated")
# sender: app, kwargs: published_workflow # sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal('app-published-workflow-was-updated') app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
# sender: app, kwargs: synced_draft_workflow # sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")

View File

@ -1,4 +1,4 @@
from blinker import signal from blinker import signal
# sender: dataset # sender: dataset
dataset_was_deleted = signal('dataset-was-deleted') dataset_was_deleted = signal("dataset-was-deleted")

View File

@ -1,4 +1,4 @@
from blinker import signal from blinker import signal
# sender: document # sender: document
document_was_deleted = signal('document-was-deleted') document_was_deleted = signal("document-was-deleted")

View File

@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect @dataset_was_deleted.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
dataset = sender dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, clean_dataset_task.delay(
dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) dataset.id,
dataset.tenant_id,
dataset.indexing_technique,
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
)

View File

@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task
@document_was_deleted.connect @document_was_deleted.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
document_id = sender document_id = sender
dataset_id = kwargs.get('dataset_id') dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get('doc_form') doc_form = kwargs.get("doc_form")
file_id = kwargs.get('file_id') file_id = kwargs.get("file_id")
clean_document_task.delay(document_id, dataset_id, doc_form, file_id) clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -14,21 +14,25 @@ from models.dataset import Document
@document_index_created.connect @document_index_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
dataset_id = sender dataset_id = sender
document_ids = kwargs.get('document_ids', None) document_ids = kwargs.get("document_ids", None)
documents = [] documents = []
start_at = time.perf_counter() start_at = time.perf_counter()
for document_id in document_ids: for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = db.session.query(Document).filter( document = (
db.session.query(Document)
.filter(
Document.id == document_id, Document.id == document_id,
Document.dataset_id == dataset_id Document.dataset_id == dataset_id,
).first() )
.first()
)
if not document: if not document:
raise NotFound('Document not found') raise NotFound("Document not found")
document.indexing_status = 'parsing' document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document) documents.append(document)
db.session.add(document) db.session.add(document)
@ -38,8 +42,8 @@ def handle(sender, **kwargs):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
indexing_runner.run(documents) indexing_runner.run(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex: except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow')) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -10,7 +10,7 @@ def handle(sender, **kwargs):
installed_app = InstalledApp( installed_app = InstalledApp(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
app_id=app.id, app_id=app.id,
app_owner_tenant_id=app.tenant_id app_owner_tenant_id=app.tenant_id,
) )
db.session.add(installed_app) db.session.add(installed_app)
db.session.commit() db.session.commit()

View File

@ -7,15 +7,15 @@ from models.model import Site
def handle(sender, **kwargs): def handle(sender, **kwargs):
"""Create site record when an app is created.""" """Create site record when an app is created."""
app = sender app = sender
account = kwargs.get('account') account = kwargs.get("account")
site = Site( site = Site(
app_id=app.id, app_id=app.id,
title=app.name, title=app.name,
icon = app.icon, icon=app.icon,
icon_background = app.icon_background, icon_background=app.icon_background,
default_language=account.interface_language, default_language=account.interface_language,
customize_token_strategy='not_allow', customize_token_strategy="not_allow",
code=Site.generate_code(16) code=Site.generate_code(16),
) )
db.session.add(site) db.session.add(site)

View File

@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity = kwargs.get('application_generate_entity') application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return return
@ -39,7 +39,7 @@ def handle(sender, **kwargs):
elif quota_unit == QuotaUnit.CREDITS: elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1 used_quota = 1
if 'gpt-4' in model_config.model: if "gpt-4" in model_config.model:
used_quota = 20 used_quota = 20
else: else:
used_quota = 1 used_quota = 1
@ -50,6 +50,6 @@ def handle(sender, **kwargs):
Provider.provider_name == model_config.provider, Provider.provider_name == model_config.provider,
Provider.provider_type == ProviderType.SYSTEM.value, Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value, Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used Provider.quota_limit > Provider.quota_used,
).update({'quota_used': Provider.quota_used + used_quota}) ).update({"quota_used": Provider.quota_used + used_quota})
db.session.commit() db.session.commit()

View File

@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced
@app_draft_workflow_was_synced.connect @app_draft_workflow_was_synced.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
app = sender app = sender
for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
if node_data.get('data', {}).get('type') == NodeType.TOOL.value: if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try: try:
tool_entity = ToolEntity(**node_data["data"]) tool_entity = ToolEntity(**node_data["data"])
tool_runtime = ToolManager.get_tool_runtime( tool_runtime = ToolManager.get_tool_runtime(
@ -23,7 +23,7 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name, provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type, provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
) )
manager.delete_tool_parameters_cache() manager.delete_tool_parameters_cache()
except: except:

View File

@ -1,4 +1,4 @@
from blinker import signal from blinker import signal
# sender: document # sender: document
document_index_created = signal('document-index-created') document_index_created = signal("document-index-created")

View File

@ -7,13 +7,11 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect @app_model_config_was_updated.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
app = sender app = sender
app_model_config = kwargs.get('app_model_config') app_model_config = kwargs.get("app_model_config")
dataset_ids = get_dataset_ids_from_model_config(app_model_config) dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter( app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
AppDatasetJoin.app_id == app.id
).all()
removed_dataset_ids = [] removed_dataset_ids = []
if not app_dataset_joins: if not app_dataset_joins:
@ -29,16 +27,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids: if removed_dataset_ids:
for dataset_id in removed_dataset_ids: for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter( db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.dataset_id == dataset_id
).delete() ).delete()
if added_dataset_ids: if added_dataset_ids:
for dataset_id in added_dataset_ids: for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin( app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
app_id=app.id,
dataset_id=dataset_id
)
db.session.add(app_dataset_join) db.session.add(app_dataset_join)
db.session.commit() db.session.commit()
@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
agent_mode = app_model_config.agent_mode_dict agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get('tools', []) or [] tools = agent_mode.get("tools", []) or []
for tool in tools: for tool in tools:
if len(list(tool.keys())) != 1: if len(list(tool.keys())) != 1:
continue continue
@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
# get dataset from dataset_configs # get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict dataset_configs = app_model_config.dataset_configs_dict
datasets = dataset_configs.get('datasets', {}) or {} datasets = dataset_configs.get("datasets", {}) or {}
for dataset in datasets.get('datasets', []) or []: for dataset in datasets.get("datasets", []) or []:
keys = list(dataset.keys()) keys = list(dataset.keys())
if len(keys) == 1 and keys[0] == 'dataset': if len(keys) == 1 and keys[0] == "dataset":
if dataset['dataset'].get('id'): if dataset["dataset"].get("id"):
dataset_ids.add(dataset['dataset'].get('id')) dataset_ids.add(dataset["dataset"].get("id"))
return dataset_ids return dataset_ids

View File

@ -11,13 +11,11 @@ from models.workflow import Workflow
@app_published_workflow_was_updated.connect @app_published_workflow_was_updated.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
app = sender app = sender
published_workflow = kwargs.get('published_workflow') published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow) published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter( app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
AppDatasetJoin.app_id == app.id
).all()
removed_dataset_ids = [] removed_dataset_ids = []
if not app_dataset_joins: if not app_dataset_joins:
@ -33,16 +31,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids: if removed_dataset_ids:
for dataset_id in removed_dataset_ids: for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter( db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.dataset_id == dataset_id
).delete() ).delete()
if added_dataset_ids: if added_dataset_ids:
for dataset_id in added_dataset_ids: for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin( app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
app_id=app.id,
dataset_id=dataset_id
)
db.session.add(app_dataset_join) db.session.add(app_dataset_join)
db.session.commit() db.session.commit()
@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
if not graph: if not graph:
return dataset_ids return dataset_ids
nodes = graph.get('nodes', []) nodes = graph.get("nodes", [])
# fetch all knowledge retrieval nodes # fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [node for node in nodes knowledge_retrieval_nodes = [
if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value] node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
]
if not knowledge_retrieval_nodes: if not knowledge_retrieval_nodes:
return dataset_ids return dataset_ids
for node in knowledge_retrieval_nodes: for node in knowledge_retrieval_nodes:
try: try:
node_data = KnowledgeRetrievalNodeData(**node.get('data', {})) node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
dataset_ids.update(node_data.dataset_ids) dataset_ids.update(node_data.dataset_ids)
except Exception as e: except Exception as e:
continue continue

View File

@ -9,13 +9,13 @@ from models.provider import Provider
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity = kwargs.get('application_generate_entity') application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return return
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider Provider.provider_name == application_generate_entity.model_conf.provider,
).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
db.session.commit() db.session.commit()

View File

@ -1,4 +1,4 @@
from blinker import signal from blinker import signal
# sender: message, kwargs: conversation # sender: message, kwargs: conversation
message_was_created = signal('message-was-created') message_was_created = signal("message-was-created")

View File

@ -1,7 +1,7 @@
from blinker import signal from blinker import signal
# sender: tenant # sender: tenant
tenant_was_created = signal('tenant-was-created') tenant_was_created = signal("tenant-was-created")
# sender: tenant # sender: tenant
tenant_was_updated = signal('tenant-was-updated') tenant_was_updated = signal("tenant-was-updated")

View File

@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery:
] ]
day = app.config["CELERY_BEAT_SCHEDULER_TIME"] day = app.config["CELERY_BEAT_SCHEDULER_TIME"]
beat_schedule = { beat_schedule = {
'clean_embedding_cache_task': { "clean_embedding_cache_task": {
'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
'schedule': timedelta(days=day), "schedule": timedelta(days=day),
},
"clean_unused_datasets_task": {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
}, },
'clean_unused_datasets_task': {
'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
'schedule': timedelta(days=day),
} }
} celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
celery_app.conf.update(
beat_schedule=beat_schedule,
imports=imports
)
return celery_app return celery_app

View File

@ -2,15 +2,14 @@ from flask import Flask
def init_app(app: Flask): def init_app(app: Flask):
if app.config.get('API_COMPRESSION_ENABLED'): if app.config.get("API_COMPRESSION_ENABLED"):
from flask_compress import Compress from flask_compress import Compress
app.config['COMPRESS_MIMETYPES'] = [ app.config["COMPRESS_MIMETYPES"] = [
'application/json', "application/json",
'image/svg+xml', "image/svg+xml",
'text/html', "text/html",
] ]
compress = Compress() compress = Compress()
compress.init_app(app) compress.init_app(app)

View File

@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData from sqlalchemy import MetaData
POSTGRES_INDEXES_NAMING_CONVENTION = { POSTGRES_INDEXES_NAMING_CONVENTION = {
'ix': '%(column_0_label)s_idx', "ix": "%(column_0_label)s_idx",
'uq': '%(table_name)s_%(column_0_name)s_key', "uq": "%(table_name)s_%(column_0_name)s_key",
'ck': '%(table_name)s_%(constraint_name)s_check', "ck": "%(table_name)s_%(constraint_name)s_check",
'fk': '%(table_name)s_%(column_0_name)s_fkey', "fk": "%(table_name)s_%(column_0_name)s_fkey",
'pk': '%(table_name)s_pkey', "pk": "%(table_name)s_pkey",
} }
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)

View File

@ -14,67 +14,69 @@ class Mail:
return self._client is not None return self._client is not None
def init_app(self, app: Flask): def init_app(self, app: Flask):
if app.config.get('MAIL_TYPE'): if app.config.get("MAIL_TYPE"):
if app.config.get('MAIL_DEFAULT_SEND_FROM'): if app.config.get("MAIL_DEFAULT_SEND_FROM"):
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
if app.config.get('MAIL_TYPE') == 'resend': if app.config.get("MAIL_TYPE") == "resend":
api_key = app.config.get('RESEND_API_KEY') api_key = app.config.get("RESEND_API_KEY")
if not api_key: if not api_key:
raise ValueError('RESEND_API_KEY is not set') raise ValueError("RESEND_API_KEY is not set")
api_url = app.config.get('RESEND_API_URL') api_url = app.config.get("RESEND_API_URL")
if api_url: if api_url:
resend.api_url = api_url resend.api_url = api_url
resend.api_key = api_key resend.api_key = api_key
self._client = resend.Emails self._client = resend.Emails
elif app.config.get('MAIL_TYPE') == 'smtp': elif app.config.get("MAIL_TYPE") == "smtp":
from libs.smtp import SMTPClient from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
self._client = SMTPClient( self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'), server=app.config.get("SMTP_SERVER"),
port=app.config.get('SMTP_PORT'), port=app.config.get("SMTP_PORT"),
username=app.config.get('SMTP_USERNAME'), username=app.config.get("SMTP_USERNAME"),
password=app.config.get('SMTP_PASSWORD'), password=app.config.get("SMTP_PASSWORD"),
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'), _from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
use_tls=app.config.get('SMTP_USE_TLS'), use_tls=app.config.get("SMTP_USE_TLS"),
opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
) )
else: else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
else: else:
logging.warning('MAIL_TYPE is not set') logging.warning("MAIL_TYPE is not set")
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client: if not self._client:
raise ValueError('Mail client is not initialized') raise ValueError("Mail client is not initialized")
if not from_ and self._default_send_from: if not from_ and self._default_send_from:
from_ = self._default_send_from from_ = self._default_send_from
if not from_: if not from_:
raise ValueError('mail from is not set') raise ValueError("mail from is not set")
if not to: if not to:
raise ValueError('mail to is not set') raise ValueError("mail to is not set")
if not subject: if not subject:
raise ValueError('mail subject is not set') raise ValueError("mail subject is not set")
if not html: if not html:
raise ValueError('mail html is not set') raise ValueError("mail html is not set")
self._client.send({ self._client.send(
{
"from": from_, "from": from_,
"to": to, "to": to,
"subject": subject, "subject": subject,
"html": html "html": html,
}) }
)
def init_app(app: Flask): def init_app(app: Flask):

View File

@ -6,18 +6,21 @@ redis_client = redis.Redis()
def init_app(app): def init_app(app):
connection_class = Connection connection_class = Connection
if app.config.get('REDIS_USE_SSL'): if app.config.get("REDIS_USE_SSL"):
connection_class = SSLConnection connection_class = SSLConnection
redis_client.connection_pool = redis.ConnectionPool(**{ redis_client.connection_pool = redis.ConnectionPool(
'host': app.config.get('REDIS_HOST'), **{
'port': app.config.get('REDIS_PORT'), "host": app.config.get("REDIS_HOST"),
'username': app.config.get('REDIS_USERNAME'), "port": app.config.get("REDIS_PORT"),
'password': app.config.get('REDIS_PASSWORD'), "username": app.config.get("REDIS_USERNAME"),
'db': app.config.get('REDIS_DB'), "password": app.config.get("REDIS_PASSWORD"),
'encoding': 'utf-8', "db": app.config.get("REDIS_DB"),
'encoding_errors': 'strict', "encoding": "utf-8",
'decode_responses': False "encoding_errors": "strict",
}, connection_class=connection_class) "decode_responses": False,
},
connection_class=connection_class,
)
app.extensions['redis'] = redis_client app.extensions["redis"] = redis_client

View File

@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException
def init_app(app): def init_app(app):
if app.config.get('SENTRY_DSN'): if app.config.get("SENTRY_DSN"):
sentry_sdk.init( sentry_sdk.init(
dsn=app.config.get('SENTRY_DSN'), dsn=app.config.get("SENTRY_DSN"),
integrations=[ integrations=[FlaskIntegration(), CeleryIntegration()],
FlaskIntegration(),
CeleryIntegration()
],
ignore_errors=[HTTPException, ValueError], ignore_errors=[HTTPException, ValueError],
traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
environment=app.config.get('DEPLOY_ENV'), environment=app.config.get("DEPLOY_ENV"),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
) )

View File

@ -17,31 +17,19 @@ class Storage:
self.storage_runner = None self.storage_runner = None
def init_app(self, app: Flask): def init_app(self, app: Flask):
storage_type = app.config.get('STORAGE_TYPE') storage_type = app.config.get("STORAGE_TYPE")
if storage_type == 's3': if storage_type == "s3":
self.storage_runner = S3Storage( self.storage_runner = S3Storage(app=app)
app=app elif storage_type == "azure-blob":
) self.storage_runner = AzureStorage(app=app)
elif storage_type == 'azure-blob': elif storage_type == "aliyun-oss":
self.storage_runner = AzureStorage( self.storage_runner = AliyunStorage(app=app)
app=app elif storage_type == "google-storage":
) self.storage_runner = GoogleStorage(app=app)
elif storage_type == 'aliyun-oss': elif storage_type == "tencent-cos":
self.storage_runner = AliyunStorage( self.storage_runner = TencentStorage(app=app)
app=app elif storage_type == "oci-storage":
) self.storage_runner = OCIStorage(app=app)
elif storage_type == 'google-storage':
self.storage_runner = GoogleStorage(
app=app
)
elif storage_type == 'tencent-cos':
self.storage_runner = TencentStorage(
app=app
)
elif storage_type == 'oci-storage':
self.storage_runner = OCIStorage(
app=app
)
else: else:
self.storage_runner = LocalStorage(app=app) self.storage_runner = LocalStorage(app=app)

View File

@ -8,23 +8,22 @@ from extensions.storage.base_storage import BaseStorage
class AliyunStorage(BaseStorage): class AliyunStorage(BaseStorage):
"""Implementation for aliyun storage. """Implementation for aliyun storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME') self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME")
oss_auth_method = aliyun_s3.Auth oss_auth_method = aliyun_s3.Auth
region = None region = None
if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4': if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4":
oss_auth_method = aliyun_s3.AuthV4 oss_auth_method = aliyun_s3.AuthV4
region = app_config.get('ALIYUN_OSS_REGION') region = app_config.get("ALIYUN_OSS_REGION")
oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')) oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY"))
self.client = aliyun_s3.Bucket( self.client = aliyun_s3.Bucket(
oss_auth, oss_auth,
app_config.get('ALIYUN_OSS_ENDPOINT'), app_config.get("ALIYUN_OSS_ENDPOINT"),
self.bucket_name, self.bucket_name,
connect_timeout=30, connect_timeout=30,
region=region, region=region,

View File

@ -9,16 +9,15 @@ from extensions.storage.base_storage import BaseStorage
class AzureStorage(BaseStorage): class AzureStorage(BaseStorage):
"""Implementation for azure storage. """Implementation for azure storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME")
self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL")
self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME")
self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY")
def save(self, filename, data): def save(self, filename, data):
client = self._sync_client() client = self._sync_client()
@ -39,6 +38,7 @@ class AzureStorage(BaseStorage):
blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob.download_blob() blob_data = blob.download_blob()
yield from blob_data.chunks() yield from blob_data.chunks()
return generate(filename) return generate(filename)
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
@ -62,17 +62,17 @@ class AzureStorage(BaseStorage):
blob_container.delete_blob(filename) blob_container.delete_blob(filename)
def _sync_client(self): def _sync_client(self):
cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
cache_result = redis_client.get(cache_key) cache_result = redis_client.get(cache_key)
if cache_result is not None: if cache_result is not None:
sas_token = cache_result.decode('utf-8') sas_token = cache_result.decode("utf-8")
else: else:
sas_token = generate_account_sas( sas_token = generate_account_sas(
account_name=self.account_name, account_name=self.account_name,
account_key=self.account_key, account_key=self.account_key,
resource_types=ResourceTypes(service=True, container=True, object=True), resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1),
) )
redis_client.set(cache_key, sas_token, ex=3000) redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url, credential=sas_token) return BlobServiceClient(account_url=self.account_url, credential=sas_token)

View File

@ -1,4 +1,5 @@
"""Abstract interface for file storage implementations.""" """Abstract interface for file storage implementations."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
@ -6,8 +7,8 @@ from flask import Flask
class BaseStorage(ABC): class BaseStorage(ABC):
"""Interface for file storage. """Interface for file storage."""
"""
app = None app = None
def __init__(self, app: Flask): def __init__(self, app: Flask):

View File

@ -11,16 +11,16 @@ from extensions.storage.base_storage import BaseStorage
class GoogleStorage(BaseStorage): class GoogleStorage(BaseStorage):
"""Implementation for google storage. """Implementation for google storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME")
service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64")
# if service_account_json_str is empty, use Application Default Credentials # if service_account_json_str is empty, use Application Default Credentials
if service_account_json_str: if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object # convert str to object
service_account_obj = json.loads(service_account_json) service_account_obj = json.loads(service_account_json)
self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj)
@ -43,9 +43,10 @@ class GoogleStorage(BaseStorage):
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
bucket = self.client.get_bucket(self.bucket_name) bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename) blob = bucket.get_blob(filename)
with closing(blob.open(mode='rb')) as blob_stream: with closing(blob.open(mode="rb")) as blob_stream:
while chunk := blob_stream.read(4096): while chunk := blob_stream.read(4096):
yield chunk yield chunk
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):

View File

@ -8,21 +8,20 @@ from extensions.storage.base_storage import BaseStorage
class LocalStorage(BaseStorage): class LocalStorage(BaseStorage):
"""Implementation for local storage. """Implementation for local storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
folder = self.app.config.get('STORAGE_LOCAL_PATH') folder = self.app.config.get("STORAGE_LOCAL_PATH")
if not os.path.isabs(folder): if not os.path.isabs(folder):
folder = os.path.join(app.root_path, folder) folder = os.path.join(app.root_path, folder)
self.folder = folder self.folder = folder
def save(self, filename, data): def save(self, filename, data):
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
folder = os.path.dirname(filename) folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
@ -31,10 +30,10 @@ class LocalStorage(BaseStorage):
f.write(data) f.write(data)
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
if not os.path.exists(filename): if not os.path.exists(filename):
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -46,10 +45,10 @@ class LocalStorage(BaseStorage):
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
if not os.path.exists(filename): if not os.path.exists(filename):
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -61,10 +60,10 @@ class LocalStorage(BaseStorage):
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
if not os.path.exists(filename): if not os.path.exists(filename):
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -72,17 +71,17 @@ class LocalStorage(BaseStorage):
shutil.copyfile(filename, target_filepath) shutil.copyfile(filename, target_filepath)
def exists(self, filename): def exists(self, filename):
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
return os.path.exists(filename) return os.path.exists(filename)
def delete(self, filename): def delete(self, filename):
if not self.folder or self.folder.endswith('/'): if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename filename = self.folder + filename
else: else:
filename = self.folder + '/' + filename filename = self.folder + "/" + filename
if os.path.exists(filename): if os.path.exists(filename):
os.remove(filename) os.remove(filename)

View File

@ -12,13 +12,13 @@ class OCIStorage(BaseStorage):
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('OCI_BUCKET_NAME') self.bucket_name = app_config.get("OCI_BUCKET_NAME")
self.client = boto3.client( self.client = boto3.client(
's3', "s3",
aws_secret_access_key=app_config.get('OCI_SECRET_KEY'), aws_secret_access_key=app_config.get("OCI_SECRET_KEY"),
aws_access_key_id=app_config.get('OCI_ACCESS_KEY'), aws_access_key_id=app_config.get("OCI_ACCESS_KEY"),
endpoint_url=app_config.get('OCI_ENDPOINT'), endpoint_url=app_config.get("OCI_ENDPOINT"),
region_name=app_config.get('OCI_REGION') region_name=app_config.get("OCI_REGION"),
) )
def save(self, filename, data): def save(self, filename, data):
@ -27,9 +27,9 @@ class OCIStorage(BaseStorage):
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
try: try:
with closing(self.client) as client: with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex: except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey': if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
else: else:
raise raise
@ -40,12 +40,13 @@ class OCIStorage(BaseStorage):
try: try:
with closing(self.client) as client: with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename) response = client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].iter_chunks() yield from response["Body"].iter_chunks()
except ClientError as ex: except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey': if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
else: else:
raise raise
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):

View File

@ -10,23 +10,23 @@ from extensions.storage.base_storage import BaseStorage
class S3Storage(BaseStorage): class S3Storage(BaseStorage):
"""Implementation for s3 storage. """Implementation for s3 storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('S3_BUCKET_NAME') self.bucket_name = app_config.get("S3_BUCKET_NAME")
if app_config.get('S3_USE_AWS_MANAGED_IAM'): if app_config.get("S3_USE_AWS_MANAGED_IAM"):
session = boto3.Session() session = boto3.Session()
self.client = session.client('s3') self.client = session.client("s3")
else: else:
self.client = boto3.client( self.client = boto3.client(
's3', "s3",
aws_secret_access_key=app_config.get('S3_SECRET_KEY'), aws_secret_access_key=app_config.get("S3_SECRET_KEY"),
aws_access_key_id=app_config.get('S3_ACCESS_KEY'), aws_access_key_id=app_config.get("S3_ACCESS_KEY"),
endpoint_url=app_config.get('S3_ENDPOINT'), endpoint_url=app_config.get("S3_ENDPOINT"),
region_name=app_config.get('S3_REGION'), region_name=app_config.get("S3_REGION"),
config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')}) config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}),
) )
def save(self, filename, data): def save(self, filename, data):
@ -35,9 +35,9 @@ class S3Storage(BaseStorage):
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
try: try:
with closing(self.client) as client: with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex: except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey': if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
else: else:
raise raise
@ -48,12 +48,13 @@ class S3Storage(BaseStorage):
try: try:
with closing(self.client) as client: with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename) response = client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].iter_chunks() yield from response["Body"].iter_chunks()
except ClientError as ex: except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey': if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
else: else:
raise raise
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):

View File

@ -7,18 +7,17 @@ from extensions.storage.base_storage import BaseStorage
class TencentStorage(BaseStorage): class TencentStorage(BaseStorage):
"""Implementation for tencent cos storage. """Implementation for tencent cos storage."""
"""
def __init__(self, app: Flask): def __init__(self, app: Flask):
super().__init__(app) super().__init__(app)
app_config = self.app.config app_config = self.app.config
self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME') self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME")
config = CosConfig( config = CosConfig(
Region=app_config.get('TENCENT_COS_REGION'), Region=app_config.get("TENCENT_COS_REGION"),
SecretId=app_config.get('TENCENT_COS_SECRET_ID'), SecretId=app_config.get("TENCENT_COS_SECRET_ID"),
SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'), SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"),
Scheme=app_config.get('TENCENT_COS_SCHEME'), Scheme=app_config.get("TENCENT_COS_SCHEME"),
) )
self.client = CosS3Client(config) self.client = CosS3Client(config)
@ -26,19 +25,19 @@ class TencentStorage(BaseStorage):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read() data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data return data
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename) response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].get_stream(chunk_size=4096) yield from response["Body"].get_stream(chunk_size=4096)
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename) response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response['Body'].get_stream_to_file(target_filepath) response["Body"].get_stream_to_file(target_filepath)
def exists(self, filename): def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename) return self.client.object_exists(Bucket=self.bucket_name, Key=filename)

View File

@ -5,7 +5,7 @@ from libs.helper import TimestampField
annotation_fields = { annotation_fields = {
"id": fields.String, "id": fields.String,
"question": fields.String, "question": fields.String,
"answer": fields.Raw(attribute='content'), "answer": fields.Raw(attribute="content"),
"hit_count": fields.Integer, "hit_count": fields.Integer,
"created_at": TimestampField, "created_at": TimestampField,
# 'account': fields.Nested(simple_account_fields, allow_null=True) # 'account': fields.Nested(simple_account_fields, allow_null=True)
@ -21,8 +21,8 @@ annotation_hit_history_fields = {
"score": fields.Float, "score": fields.Float,
"question": fields.String, "question": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
"match": fields.String(attribute='annotation_question'), "match": fields.String(attribute="annotation_question"),
"response": fields.String(attribute='annotation_content') "response": fields.String(attribute="annotation_content"),
} }
annotation_hit_history_list_fields = { annotation_hit_history_list_fields = {

View File

@ -8,16 +8,16 @@ class HiddenAPIKey(fields.Raw):
api_key = obj.api_key api_key = obj.api_key
# If the length of the api_key is less than 8 characters, show the first and last characters # If the length of the api_key is less than 8 characters, show the first and last characters
if len(api_key) <= 8: if len(api_key) <= 8:
return api_key[0] + '******' + api_key[-1] return api_key[0] + "******" + api_key[-1]
# If the api_key is greater than 8 characters, show the first three and the last three characters # If the api_key is greater than 8 characters, show the first three and the last three characters
else: else:
return api_key[:3] + '******' + api_key[-3:] return api_key[:3] + "******" + api_key[-3:]
api_based_extension_fields = { api_based_extension_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'api_endpoint': fields.String, "api_endpoint": fields.String,
'api_key': HiddenAPIKey, "api_key": HiddenAPIKey,
'created_at': TimestampField "created_at": TimestampField,
} }

View File

@ -3,157 +3,153 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
app_detail_kernel_fields = { app_detail_kernel_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'description': fields.String, "description": fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'), "mode": fields.String(attribute="mode_compatible_with_agent"),
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
} }
related_app_list = { related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)), "data": fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer, "total": fields.Integer,
} }
model_config_fields = { model_config_fields = {
'opening_statement': fields.String, "opening_statement": fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'), "suggested_questions": fields.Raw(attribute="suggested_questions_list"),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), "speech_to_text": fields.Raw(attribute="speech_to_text_dict"),
'text_to_speech': fields.Raw(attribute='text_to_speech_dict'), "text_to_speech": fields.Raw(attribute="text_to_speech_dict"),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), "retriever_resource": fields.Raw(attribute="retriever_resource_dict"),
'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), "annotation_reply": fields.Raw(attribute="annotation_reply_dict"),
'more_like_this': fields.Raw(attribute='more_like_this_dict'), "more_like_this": fields.Raw(attribute="more_like_this_dict"),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"),
'external_data_tools': fields.Raw(attribute='external_data_tools_list'), "external_data_tools": fields.Raw(attribute="external_data_tools_list"),
'model': fields.Raw(attribute='model_dict'), "model": fields.Raw(attribute="model_dict"),
'user_input_form': fields.Raw(attribute='user_input_form_list'), "user_input_form": fields.Raw(attribute="user_input_form_list"),
'dataset_query_variable': fields.String, "dataset_query_variable": fields.String,
'pre_prompt': fields.String, "pre_prompt": fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'), "agent_mode": fields.Raw(attribute="agent_mode_dict"),
'prompt_type': fields.String, "prompt_type": fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), "dataset_configs": fields.Raw(attribute="dataset_configs_dict"),
'file_upload': fields.Raw(attribute='file_upload_dict'), "file_upload": fields.Raw(attribute="file_upload_dict"),
'created_at': TimestampField "created_at": TimestampField,
} }
app_detail_fields = { app_detail_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'description': fields.String, "description": fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'), "mode": fields.String(attribute="mode_compatible_with_agent"),
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'enable_site': fields.Boolean, "enable_site": fields.Boolean,
'enable_api': fields.Boolean, "enable_api": fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
'tracing': fields.Raw, "tracing": fields.Raw,
'created_at': TimestampField "created_at": TimestampField,
} }
prompt_config_fields = { prompt_config_fields = {
'prompt_template': fields.String, "prompt_template": fields.String,
} }
model_config_partial_fields = { model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'), "model": fields.Raw(attribute="model_dict"),
'pre_prompt': fields.String, "pre_prompt": fields.String,
} }
tag_fields = { tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
'id': fields.String,
'name': fields.String,
'type': fields.String
}
app_partial_fields = { app_partial_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'max_active_requests': fields.Raw(), "max_active_requests": fields.Raw(),
'description': fields.String(attribute='desc_or_prompt'), "description": fields.String(attribute="desc_or_prompt"),
'mode': fields.String(attribute='mode_compatible_with_agent'), "mode": fields.String(attribute="mode_compatible_with_agent"),
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True),
'created_at': TimestampField, "created_at": TimestampField,
'tags': fields.List(fields.Nested(tag_fields)) "tags": fields.List(fields.Nested(tag_fields)),
} }
app_pagination_fields = { app_pagination_fields = {
'page': fields.Integer, "page": fields.Integer,
'limit': fields.Integer(attribute='per_page'), "limit": fields.Integer(attribute="per_page"),
'total': fields.Integer, "total": fields.Integer,
'has_more': fields.Boolean(attribute='has_next'), "has_more": fields.Boolean(attribute="has_next"),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items') "data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
} }
template_fields = { template_fields = {
'name': fields.String, "name": fields.String,
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'description': fields.String, "description": fields.String,
'mode': fields.String, "mode": fields.String,
'model_config': fields.Nested(model_config_fields), "model_config": fields.Nested(model_config_fields),
} }
template_list_fields = { template_list_fields = {
'data': fields.List(fields.Nested(template_fields)), "data": fields.List(fields.Nested(template_fields)),
} }
site_fields = { site_fields = {
'access_token': fields.String(attribute='code'), "access_token": fields.String(attribute="code"),
'code': fields.String, "code": fields.String,
'title': fields.String, "title": fields.String,
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'description': fields.String, "description": fields.String,
'default_language': fields.String, "default_language": fields.String,
'chat_color_theme': fields.String, "chat_color_theme": fields.String,
'chat_color_theme_inverted': fields.Boolean, "chat_color_theme_inverted": fields.Boolean,
'customize_domain': fields.String, "customize_domain": fields.String,
'copyright': fields.String, "copyright": fields.String,
'privacy_policy': fields.String, "privacy_policy": fields.String,
'custom_disclaimer': fields.String, "custom_disclaimer": fields.String,
'customize_token_strategy': fields.String, "customize_token_strategy": fields.String,
'prompt_public': fields.Boolean, "prompt_public": fields.Boolean,
'app_base_url': fields.String, "app_base_url": fields.String,
'show_workflow_steps': fields.Boolean, "show_workflow_steps": fields.Boolean,
} }
app_detail_fields_with_site = { app_detail_fields_with_site = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'description': fields.String, "description": fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'), "mode": fields.String(attribute="mode_compatible_with_agent"),
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'enable_site': fields.Boolean, "enable_site": fields.Boolean,
'enable_api': fields.Boolean, "enable_api": fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
'site': fields.Nested(site_fields), "site": fields.Nested(site_fields),
'api_base_url': fields.String, "api_base_url": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'deleted_tools': fields.List(fields.String), "deleted_tools": fields.List(fields.String),
} }
app_site_fields = { app_site_fields = {
'app_id': fields.String, "app_id": fields.String,
'access_token': fields.String(attribute='code'), "access_token": fields.String(attribute="code"),
'code': fields.String, "code": fields.String,
'title': fields.String, "title": fields.String,
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String, "icon_background": fields.String,
'description': fields.String, "description": fields.String,
'default_language': fields.String, "default_language": fields.String,
'customize_domain': fields.String, "customize_domain": fields.String,
'copyright': fields.String, "copyright": fields.String,
'privacy_policy': fields.String, "privacy_policy": fields.String,
'custom_disclaimer': fields.String, "custom_disclaimer": fields.String,
'customize_token_strategy': fields.String, "customize_token_strategy": fields.String,
'prompt_public': fields.Boolean, "prompt_public": fields.Boolean,
'show_workflow_steps': fields.Boolean, "show_workflow_steps": fields.Boolean,
} }

View File

@ -6,205 +6,202 @@ from libs.helper import TimestampField
class MessageTextField(fields.Raw): class MessageTextField(fields.Raw):
def format(self, value): def format(self, value):
return value[0]['text'] if value else '' return value[0]["text"] if value else ""
feedback_fields = { feedback_fields = {
'rating': fields.String, "rating": fields.String,
'content': fields.String, "content": fields.String,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_account': fields.Nested(simple_account_fields, allow_null=True), "from_account": fields.Nested(simple_account_fields, allow_null=True),
} }
annotation_fields = { annotation_fields = {
'id': fields.String, "id": fields.String,
'question': fields.String, "question": fields.String,
'content': fields.String, "content": fields.String,
'account': fields.Nested(simple_account_fields, allow_null=True), "account": fields.Nested(simple_account_fields, allow_null=True),
'created_at': TimestampField "created_at": TimestampField,
} }
annotation_hit_history_fields = { annotation_hit_history_fields = {
'annotation_id': fields.String(attribute='id'), "annotation_id": fields.String(attribute="id"),
'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
'created_at': TimestampField "created_at": TimestampField,
} }
message_file_fields = { message_file_fields = {
'id': fields.String, "id": fields.String,
'type': fields.String, "type": fields.String,
'url': fields.String, "url": fields.String,
'belongs_to': fields.String(default='user'), "belongs_to": fields.String(default="user"),
} }
agent_thought_fields = { agent_thought_fields = {
'id': fields.String, "id": fields.String,
'chain_id': fields.String, "chain_id": fields.String,
'message_id': fields.String, "message_id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'thought': fields.String, "thought": fields.String,
'tool': fields.String, "tool": fields.String,
'tool_labels': fields.Raw, "tool_labels": fields.Raw,
'tool_input': fields.String, "tool_input": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'observation': fields.String, "observation": fields.String,
'files': fields.List(fields.String), "files": fields.List(fields.String),
} }
message_detail_fields = { message_detail_fields = {
'id': fields.String, "id": fields.String,
'conversation_id': fields.String, "conversation_id": fields.String,
'inputs': fields.Raw, "inputs": fields.Raw,
'query': fields.String, "query": fields.String,
'message': fields.Raw, "message": fields.Raw,
'message_tokens': fields.Integer, "message_tokens": fields.Integer,
'answer': fields.String(attribute='re_sign_file_url_answer'), "answer": fields.String(attribute="re_sign_file_url_answer"),
'answer_tokens': fields.Integer, "answer_tokens": fields.Integer,
'provider_response_latency': fields.Float, "provider_response_latency": fields.Float,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_account_id': fields.String, "from_account_id": fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)), "feedbacks": fields.List(fields.Nested(feedback_fields)),
'workflow_run_id': fields.String, "workflow_run_id": fields.String,
'annotation': fields.Nested(annotation_fields, allow_null=True), "annotation": fields.Nested(annotation_fields, allow_null=True),
'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
'created_at': TimestampField, "created_at": TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
'metadata': fields.Raw(attribute='message_metadata_dict'), "metadata": fields.Raw(attribute="message_metadata_dict"),
'status': fields.String, "status": fields.String,
'error': fields.String, "error": fields.String,
} }
feedback_stat_fields = { feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = { model_config_fields = {
'opening_statement': fields.String, "opening_statement": fields.String,
'suggested_questions': fields.Raw, "suggested_questions": fields.Raw,
'model': fields.Raw, "model": fields.Raw,
'user_input_form': fields.Raw, "user_input_form": fields.Raw,
'pre_prompt': fields.String, "pre_prompt": fields.String,
'agent_mode': fields.Raw, "agent_mode": fields.Raw,
} }
simple_configs_fields = { simple_configs_fields = {
'prompt_template': fields.String, "prompt_template": fields.String,
} }
simple_model_config_fields = { simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'), "model": fields.Raw(attribute="model_dict"),
'pre_prompt': fields.String, "pre_prompt": fields.String,
} }
simple_message_detail_fields = { simple_message_detail_fields = {
'inputs': fields.Raw, "inputs": fields.Raw,
'query': fields.String, "query": fields.String,
'message': MessageTextField, "message": MessageTextField,
'answer': fields.String, "answer": fields.String,
} }
conversation_fields = { conversation_fields = {
'id': fields.String, "id": fields.String,
'status': fields.String, "status": fields.String,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_end_user_session_id': fields.String(), "from_end_user_session_id": fields.String(),
'from_account_id': fields.String, "from_account_id": fields.String,
'read_at': TimestampField, "read_at": TimestampField,
'created_at': TimestampField, "created_at": TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True), "annotation": fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields), "model_config": fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields), "admin_feedback_stats": fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message') "message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
} }
conversation_pagination_fields = { conversation_pagination_fields = {
'page': fields.Integer, "page": fields.Integer,
'limit': fields.Integer(attribute='per_page'), "limit": fields.Integer(attribute="per_page"),
'total': fields.Integer, "total": fields.Integer,
'has_more': fields.Boolean(attribute='has_next'), "has_more": fields.Boolean(attribute="has_next"),
'data': fields.List(fields.Nested(conversation_fields), attribute='items') "data": fields.List(fields.Nested(conversation_fields), attribute="items"),
} }
conversation_message_detail_fields = { conversation_message_detail_fields = {
'id': fields.String, "id": fields.String,
'status': fields.String, "status": fields.String,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_account_id': fields.String, "from_account_id": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'model_config': fields.Nested(model_config_fields), "model_config": fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'), "message": fields.Nested(message_detail_fields, attribute="first_message"),
} }
conversation_with_summary_fields = { conversation_with_summary_fields = {
'id': fields.String, "id": fields.String,
'status': fields.String, "status": fields.String,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_end_user_session_id': fields.String, "from_end_user_session_id": fields.String,
'from_account_id': fields.String, "from_account_id": fields.String,
'name': fields.String, "name": fields.String,
'summary': fields.String(attribute='summary_or_query'), "summary": fields.String(attribute="summary_or_query"),
'read_at': TimestampField, "read_at": TimestampField,
'created_at': TimestampField, "created_at": TimestampField,
'annotated': fields.Boolean, "annotated": fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields), "model_config": fields.Nested(simple_model_config_fields),
'message_count': fields.Integer, "message_count": fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields) "admin_feedback_stats": fields.Nested(feedback_stat_fields),
} }
conversation_with_summary_pagination_fields = { conversation_with_summary_pagination_fields = {
'page': fields.Integer, "page": fields.Integer,
'limit': fields.Integer(attribute='per_page'), "limit": fields.Integer(attribute="per_page"),
'total': fields.Integer, "total": fields.Integer,
'has_more': fields.Boolean(attribute='has_next'), "has_more": fields.Boolean(attribute="has_next"),
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
} }
conversation_detail_fields = { conversation_detail_fields = {
'id': fields.String, "id": fields.String,
'status': fields.String, "status": fields.String,
'from_source': fields.String, "from_source": fields.String,
'from_end_user_id': fields.String, "from_end_user_id": fields.String,
'from_account_id': fields.String, "from_account_id": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'annotated': fields.Boolean, "annotated": fields.Boolean,
'introduction': fields.String, "introduction": fields.String,
'model_config': fields.Nested(model_config_fields), "model_config": fields.Nested(model_config_fields),
'message_count': fields.Integer, "message_count": fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields) "admin_feedback_stats": fields.Nested(feedback_stat_fields),
} }
simple_conversation_fields = { simple_conversation_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'inputs': fields.Raw, "inputs": fields.Raw,
'status': fields.String, "status": fields.String,
'introduction': fields.String, "introduction": fields.String,
'created_at': TimestampField "created_at": TimestampField,
} }
conversation_infinite_scroll_pagination_fields = { conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer, "limit": fields.Integer,
'has_more': fields.Boolean, "has_more": fields.Boolean,
'data': fields.List(fields.Nested(simple_conversation_fields)) "data": fields.List(fields.Nested(simple_conversation_fields)),
} }
conversation_with_model_config_fields = { conversation_with_model_config_fields = {
**simple_conversation_fields, **simple_conversation_fields,
'model_config': fields.Raw, "model_config": fields.Raw,
} }
conversation_with_model_config_infinite_scroll_pagination_fields = { conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer, "limit": fields.Integer,
'has_more': fields.Boolean, "has_more": fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields)) "data": fields.List(fields.Nested(conversation_with_model_config_fields)),
} }

View File

@ -3,19 +3,19 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
conversation_variable_fields = { conversation_variable_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'value_type': fields.String(attribute='value_type.value'), "value_type": fields.String(attribute="value_type.value"),
'value': fields.String, "value": fields.String,
'description': fields.String, "description": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'updated_at': TimestampField, "updated_at": TimestampField,
} }
paginated_conversation_variable_fields = { paginated_conversation_variable_fields = {
'page': fields.Integer, "page": fields.Integer,
'limit': fields.Integer, "limit": fields.Integer,
'total': fields.Integer, "total": fields.Integer,
'has_more': fields.Boolean, "has_more": fields.Boolean,
'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'), "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"),
} }

View File

@ -2,64 +2,56 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
integrate_icon_fields = { integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = { integrate_page_fields = {
'page_name': fields.String, "page_name": fields.String,
'page_id': fields.String, "page_id": fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), "page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean, "is_bound": fields.Boolean,
'parent_id': fields.String, "parent_id": fields.String,
'type': fields.String "type": fields.String,
} }
integrate_workspace_fields = { integrate_workspace_fields = {
'workspace_name': fields.String, "workspace_name": fields.String,
'workspace_id': fields.String, "workspace_id": fields.String,
'workspace_icon': fields.String, "workspace_icon": fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)) "pages": fields.List(fields.Nested(integrate_page_fields)),
} }
integrate_notion_info_list_fields = { integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), "notion_info": fields.List(fields.Nested(integrate_workspace_fields)),
} }
integrate_icon_fields = { integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = { integrate_page_fields = {
'page_name': fields.String, "page_name": fields.String,
'page_id': fields.String, "page_id": fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), "page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String, "parent_id": fields.String,
'type': fields.String "type": fields.String,
} }
integrate_workspace_fields = { integrate_workspace_fields = {
'workspace_name': fields.String, "workspace_name": fields.String,
'workspace_id': fields.String, "workspace_id": fields.String,
'workspace_icon': fields.String, "workspace_icon": fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)), "pages": fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer "total": fields.Integer,
} }
integrate_fields = { integrate_fields = {
'id': fields.String, "id": fields.String,
'provider': fields.String, "provider": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'is_bound': fields.Boolean, "is_bound": fields.Boolean,
'disabled': fields.Boolean, "disabled": fields.Boolean,
'link': fields.String, "link": fields.String,
'source_info': fields.Nested(integrate_workspace_fields) "source_info": fields.Nested(integrate_workspace_fields),
} }
integrate_list_fields = { integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)), "data": fields.List(fields.Nested(integrate_fields)),
} }

View File

@ -3,73 +3,64 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
dataset_fields = { dataset_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'description': fields.String, "description": fields.String,
'permission': fields.String, "permission": fields.String,
'data_source_type': fields.String, "data_source_type": fields.String,
'indexing_technique': fields.String, "indexing_technique": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
} }
reranking_model_fields = { reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}
'reranking_provider_name': fields.String,
'reranking_model_name': fields.String
}
keyword_setting_fields = { keyword_setting_fields = {"keyword_weight": fields.Float}
'keyword_weight': fields.Float
}
vector_setting_fields = { vector_setting_fields = {
'vector_weight': fields.Float, "vector_weight": fields.Float,
'embedding_model_name': fields.String, "embedding_model_name": fields.String,
'embedding_provider_name': fields.String, "embedding_provider_name": fields.String,
} }
weighted_score_fields = { weighted_score_fields = {
'keyword_setting': fields.Nested(keyword_setting_fields), "keyword_setting": fields.Nested(keyword_setting_fields),
'vector_setting': fields.Nested(vector_setting_fields), "vector_setting": fields.Nested(vector_setting_fields),
} }
dataset_retrieval_model_fields = { dataset_retrieval_model_fields = {
'search_method': fields.String, "search_method": fields.String,
'reranking_enable': fields.Boolean, "reranking_enable": fields.Boolean,
'reranking_mode': fields.String, "reranking_mode": fields.String,
'reranking_model': fields.Nested(reranking_model_fields), "reranking_model": fields.Nested(reranking_model_fields),
'weights': fields.Nested(weighted_score_fields, allow_null=True), "weights": fields.Nested(weighted_score_fields, allow_null=True),
'top_k': fields.Integer, "top_k": fields.Integer,
'score_threshold_enabled': fields.Boolean, "score_threshold_enabled": fields.Boolean,
'score_threshold': fields.Float "score_threshold": fields.Float,
} }
tag_fields = { tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
'id': fields.String,
'name': fields.String,
'type': fields.String
}
dataset_detail_fields = { dataset_detail_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'description': fields.String, "description": fields.String,
'provider': fields.String, "provider": fields.String,
'permission': fields.String, "permission": fields.String,
'data_source_type': fields.String, "data_source_type": fields.String,
'indexing_technique': fields.String, "indexing_technique": fields.String,
'app_count': fields.Integer, "app_count": fields.Integer,
'document_count': fields.Integer, "document_count": fields.Integer,
'word_count': fields.Integer, "word_count": fields.Integer,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'updated_by': fields.String, "updated_by": fields.String,
'updated_at': TimestampField, "updated_at": TimestampField,
'embedding_model': fields.String, "embedding_model": fields.String,
'embedding_model_provider': fields.String, "embedding_model_provider": fields.String,
'embedding_available': fields.Boolean, "embedding_available": fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields), "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
'tags': fields.List(fields.Nested(tag_fields)) "tags": fields.List(fields.Nested(tag_fields)),
} }
dataset_query_detail_fields = { dataset_query_detail_fields = {
@ -79,7 +70,5 @@ dataset_query_detail_fields = {
"source_app_id": fields.String, "source_app_id": fields.String,
"created_by_role": fields.String, "created_by_role": fields.String,
"created_by": fields.String, "created_by": fields.String,
"created_at": TimestampField "created_at": TimestampField,
} }

View File

@ -4,75 +4,73 @@ from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField from libs.helper import TimestampField
document_fields = { document_fields = {
'id': fields.String, "id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'data_source_type': fields.String, "data_source_type": fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'), "data_source_info": fields.Raw(attribute="data_source_info_dict"),
'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
'dataset_process_rule_id': fields.String, "dataset_process_rule_id": fields.String,
'name': fields.String, "name": fields.String,
'created_from': fields.String, "created_from": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'tokens': fields.Integer, "tokens": fields.Integer,
'indexing_status': fields.String, "indexing_status": fields.String,
'error': fields.String, "error": fields.String,
'enabled': fields.Boolean, "enabled": fields.Boolean,
'disabled_at': TimestampField, "disabled_at": TimestampField,
'disabled_by': fields.String, "disabled_by": fields.String,
'archived': fields.Boolean, "archived": fields.Boolean,
'display_status': fields.String, "display_status": fields.String,
'word_count': fields.Integer, "word_count": fields.Integer,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'doc_form': fields.String, "doc_form": fields.String,
} }
document_with_segments_fields = { document_with_segments_fields = {
'id': fields.String, "id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'data_source_type': fields.String, "data_source_type": fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'), "data_source_info": fields.Raw(attribute="data_source_info_dict"),
'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
'dataset_process_rule_id': fields.String, "dataset_process_rule_id": fields.String,
'name': fields.String, "name": fields.String,
'created_from': fields.String, "created_from": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'tokens': fields.Integer, "tokens": fields.Integer,
'indexing_status': fields.String, "indexing_status": fields.String,
'error': fields.String, "error": fields.String,
'enabled': fields.Boolean, "enabled": fields.Boolean,
'disabled_at': TimestampField, "disabled_at": TimestampField,
'disabled_by': fields.String, "disabled_by": fields.String,
'archived': fields.Boolean, "archived": fields.Boolean,
'display_status': fields.String, "display_status": fields.String,
'word_count': fields.Integer, "word_count": fields.Integer,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'completed_segments': fields.Integer, "completed_segments": fields.Integer,
'total_segments': fields.Integer "total_segments": fields.Integer,
} }
dataset_and_document_fields = { dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields), "dataset": fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)), "documents": fields.List(fields.Nested(document_fields)),
'batch': fields.String "batch": fields.String,
} }
document_status_fields = { document_status_fields = {
'id': fields.String, "id": fields.String,
'indexing_status': fields.String, "indexing_status": fields.String,
'processing_started_at': TimestampField, "processing_started_at": TimestampField,
'parsing_completed_at': TimestampField, "parsing_completed_at": TimestampField,
'cleaning_completed_at': TimestampField, "cleaning_completed_at": TimestampField,
'splitting_completed_at': TimestampField, "splitting_completed_at": TimestampField,
'completed_at': TimestampField, "completed_at": TimestampField,
'paused_at': TimestampField, "paused_at": TimestampField,
'error': fields.String, "error": fields.String,
'stopped_at': TimestampField, "stopped_at": TimestampField,
'completed_segments': fields.Integer, "completed_segments": fields.Integer,
'total_segments': fields.Integer, "total_segments": fields.Integer,
} }
document_status_fields_list = { document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))}
'data': fields.List(fields.Nested(document_status_fields))
}

View File

@ -1,8 +1,8 @@
from flask_restful import fields from flask_restful import fields
simple_end_user_fields = { simple_end_user_fields = {
'id': fields.String, "id": fields.String,
'type': fields.String, "type": fields.String,
'is_anonymous': fields.Boolean, "is_anonymous": fields.Boolean,
'session_id': fields.String, "session_id": fields.String,
} }

View File

@ -3,17 +3,17 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
upload_config_fields = { upload_config_fields = {
'file_size_limit': fields.Integer, "file_size_limit": fields.Integer,
'batch_count_limit': fields.Integer, "batch_count_limit": fields.Integer,
'image_file_size_limit': fields.Integer, "image_file_size_limit": fields.Integer,
} }
file_fields = { file_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'size': fields.Integer, "size": fields.Integer,
'extension': fields.String, "extension": fields.String,
'mime_type': fields.String, "mime_type": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
} }

View File

@ -3,39 +3,39 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
document_fields = { document_fields = {
'id': fields.String, "id": fields.String,
'data_source_type': fields.String, "data_source_type": fields.String,
'name': fields.String, "name": fields.String,
'doc_type': fields.String, "doc_type": fields.String,
} }
segment_fields = { segment_fields = {
'id': fields.String, "id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'document_id': fields.String, "document_id": fields.String,
'content': fields.String, "content": fields.String,
'answer': fields.String, "answer": fields.String,
'word_count': fields.Integer, "word_count": fields.Integer,
'tokens': fields.Integer, "tokens": fields.Integer,
'keywords': fields.List(fields.String), "keywords": fields.List(fields.String),
'index_node_id': fields.String, "index_node_id": fields.String,
'index_node_hash': fields.String, "index_node_hash": fields.String,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'enabled': fields.Boolean, "enabled": fields.Boolean,
'disabled_at': TimestampField, "disabled_at": TimestampField,
'disabled_by': fields.String, "disabled_by": fields.String,
'status': fields.String, "status": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'indexing_at': TimestampField, "indexing_at": TimestampField,
'completed_at': TimestampField, "completed_at": TimestampField,
'error': fields.String, "error": fields.String,
'stopped_at': TimestampField, "stopped_at": TimestampField,
'document': fields.Nested(document_fields), "document": fields.Nested(document_fields),
} }
hit_testing_record_fields = { hit_testing_record_fields = {
'segment': fields.Nested(segment_fields), "segment": fields.Nested(segment_fields),
'score': fields.Float, "score": fields.Float,
'tsne_position': fields.Raw "tsne_position": fields.Raw,
} }

View File

@ -3,23 +3,21 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
app_fields = { app_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'mode': fields.String, "mode": fields.String,
'icon': fields.String, "icon": fields.String,
'icon_background': fields.String "icon_background": fields.String,
} }
installed_app_fields = { installed_app_fields = {
'id': fields.String, "id": fields.String,
'app': fields.Nested(app_fields), "app": fields.Nested(app_fields),
'app_owner_tenant_id': fields.String, "app_owner_tenant_id": fields.String,
'is_pinned': fields.Boolean, "is_pinned": fields.Boolean,
'last_used_at': TimestampField, "last_used_at": TimestampField,
'editable': fields.Boolean, "editable": fields.Boolean,
'uninstallable': fields.Boolean "uninstallable": fields.Boolean,
} }
installed_app_list_fields = { installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))}
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}

View File

@ -2,38 +2,32 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
simple_account_fields = { simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}
'id': fields.String,
'name': fields.String,
'email': fields.String
}
account_fields = { account_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'avatar': fields.String, "avatar": fields.String,
'email': fields.String, "email": fields.String,
'is_password_set': fields.Boolean, "is_password_set": fields.Boolean,
'interface_language': fields.String, "interface_language": fields.String,
'interface_theme': fields.String, "interface_theme": fields.String,
'timezone': fields.String, "timezone": fields.String,
'last_login_at': TimestampField, "last_login_at": TimestampField,
'last_login_ip': fields.String, "last_login_ip": fields.String,
'created_at': TimestampField "created_at": TimestampField,
} }
account_with_role_fields = { account_with_role_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'avatar': fields.String, "avatar": fields.String,
'email': fields.String, "email": fields.String,
'last_login_at': TimestampField, "last_login_at": TimestampField,
'last_active_at': TimestampField, "last_active_at": TimestampField,
'created_at': TimestampField, "created_at": TimestampField,
'role': fields.String, "role": fields.String,
'status': fields.String, "status": fields.String,
} }
account_with_role_list_fields = { account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
'accounts': fields.List(fields.Nested(account_with_role_fields))
}

View File

@ -3,83 +3,79 @@ from flask_restful import fields
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField from libs.helper import TimestampField
feedback_fields = { feedback_fields = {"rating": fields.String}
'rating': fields.String
}
retriever_resource_fields = { retriever_resource_fields = {
'id': fields.String, "id": fields.String,
'message_id': fields.String, "message_id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'dataset_id': fields.String, "dataset_id": fields.String,
'dataset_name': fields.String, "dataset_name": fields.String,
'document_id': fields.String, "document_id": fields.String,
'document_name': fields.String, "document_name": fields.String,
'data_source_type': fields.String, "data_source_type": fields.String,
'segment_id': fields.String, "segment_id": fields.String,
'score': fields.Float, "score": fields.Float,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'word_count': fields.Integer, "word_count": fields.Integer,
'segment_position': fields.Integer, "segment_position": fields.Integer,
'index_node_hash': fields.String, "index_node_hash": fields.String,
'content': fields.String, "content": fields.String,
'created_at': TimestampField "created_at": TimestampField,
} }
feedback_fields = { feedback_fields = {"rating": fields.String}
'rating': fields.String
}
agent_thought_fields = { agent_thought_fields = {
'id': fields.String, "id": fields.String,
'chain_id': fields.String, "chain_id": fields.String,
'message_id': fields.String, "message_id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'thought': fields.String, "thought": fields.String,
'tool': fields.String, "tool": fields.String,
'tool_labels': fields.Raw, "tool_labels": fields.Raw,
'tool_input': fields.String, "tool_input": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'observation': fields.String, "observation": fields.String,
'files': fields.List(fields.String) "files": fields.List(fields.String),
} }
retriever_resource_fields = { retriever_resource_fields = {
'id': fields.String, "id": fields.String,
'message_id': fields.String, "message_id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'dataset_id': fields.String, "dataset_id": fields.String,
'dataset_name': fields.String, "dataset_name": fields.String,
'document_id': fields.String, "document_id": fields.String,
'document_name': fields.String, "document_name": fields.String,
'data_source_type': fields.String, "data_source_type": fields.String,
'segment_id': fields.String, "segment_id": fields.String,
'score': fields.Float, "score": fields.Float,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'word_count': fields.Integer, "word_count": fields.Integer,
'segment_position': fields.Integer, "segment_position": fields.Integer,
'index_node_hash': fields.String, "index_node_hash": fields.String,
'content': fields.String, "content": fields.String,
'created_at': TimestampField "created_at": TimestampField,
} }
message_fields = { message_fields = {
'id': fields.String, "id": fields.String,
'conversation_id': fields.String, "conversation_id": fields.String,
'inputs': fields.Raw, "inputs": fields.Raw,
'query': fields.String, "query": fields.String,
'answer': fields.String(attribute='re_sign_file_url_answer'), "answer": fields.String(attribute="re_sign_file_url_answer"),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField, "created_at": TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
'status': fields.String, "status": fields.String,
'error': fields.String, "error": fields.String,
} }
message_infinite_scroll_pagination_fields = { message_infinite_scroll_pagination_fields = {
'limit': fields.Integer, "limit": fields.Integer,
'has_more': fields.Boolean, "has_more": fields.Boolean,
'data': fields.List(fields.Nested(message_fields)) "data": fields.List(fields.Nested(message_fields)),
} }

View File

@ -3,31 +3,31 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
segment_fields = { segment_fields = {
'id': fields.String, "id": fields.String,
'position': fields.Integer, "position": fields.Integer,
'document_id': fields.String, "document_id": fields.String,
'content': fields.String, "content": fields.String,
'answer': fields.String, "answer": fields.String,
'word_count': fields.Integer, "word_count": fields.Integer,
'tokens': fields.Integer, "tokens": fields.Integer,
'keywords': fields.List(fields.String), "keywords": fields.List(fields.String),
'index_node_id': fields.String, "index_node_id": fields.String,
'index_node_hash': fields.String, "index_node_hash": fields.String,
'hit_count': fields.Integer, "hit_count": fields.Integer,
'enabled': fields.Boolean, "enabled": fields.Boolean,
'disabled_at': TimestampField, "disabled_at": TimestampField,
'disabled_by': fields.String, "disabled_by": fields.String,
'status': fields.String, "status": fields.String,
'created_by': fields.String, "created_by": fields.String,
'created_at': TimestampField, "created_at": TimestampField,
'indexing_at': TimestampField, "indexing_at": TimestampField,
'completed_at': TimestampField, "completed_at": TimestampField,
'error': fields.String, "error": fields.String,
'stopped_at': TimestampField "stopped_at": TimestampField,
} }
segment_list_response = { segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)), "data": fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean, "has_more": fields.Boolean,
'limit': fields.Integer "limit": fields.Integer,
} }

View File

@ -1,8 +1,3 @@
from flask_restful import fields from flask_restful import fields
tag_fields = { tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}
'id': fields.String,
'name': fields.String,
'type': fields.String,
'binding_count': fields.String
}

View File

@ -7,18 +7,18 @@ from libs.helper import TimestampField
workflow_app_log_partial_fields = { workflow_app_log_partial_fields = {
"id": fields.String, "id": fields.String,
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
"created_from": fields.String, "created_from": fields.String,
"created_by_role": fields.String, "created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField "created_at": TimestampField,
} }
workflow_app_log_pagination_fields = { workflow_app_log_pagination_fields = {
'page': fields.Integer, "page": fields.Integer,
'limit': fields.Integer(attribute='per_page'), "limit": fields.Integer(attribute="per_page"),
'total': fields.Integer, "total": fields.Integer,
'has_more': fields.Boolean(attribute='has_next'), "has_more": fields.Boolean(attribute="has_next"),
'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"),
} }

View File

@ -13,43 +13,43 @@ class EnvironmentVariableField(fields.Raw):
# Mask secret variables values in environment_variables # Mask secret variables values in environment_variables
if isinstance(value, SecretVariable): if isinstance(value, SecretVariable):
return { return {
'id': value.id, "id": value.id,
'name': value.name, "name": value.name,
'value': encrypter.obfuscated_token(value.value), "value": encrypter.obfuscated_token(value.value),
'value_type': value.value_type.value, "value_type": value.value_type.value,
} }
if isinstance(value, Variable): if isinstance(value, Variable):
return { return {
'id': value.id, "id": value.id,
'name': value.name, "name": value.name,
'value': value.value, "value": value.value,
'value_type': value.value_type.value, "value_type": value.value_type.value,
} }
if isinstance(value, dict): if isinstance(value, dict):
value_type = value.get('value_type') value_type = value.get("value_type")
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f'Unsupported environment variable value type: {value_type}') raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value return value
conversation_variable_fields = { conversation_variable_fields = {
'id': fields.String, "id": fields.String,
'name': fields.String, "name": fields.String,
'value_type': fields.String(attribute='value_type.value'), "value_type": fields.String(attribute="value_type.value"),
'value': fields.Raw, "value": fields.Raw,
'description': fields.String, "description": fields.String,
} }
workflow_fields = { workflow_fields = {
'id': fields.String, "id": fields.String,
'graph': fields.Raw(attribute='graph_dict'), "graph": fields.Raw(attribute="graph_dict"),
'features': fields.Raw(attribute='features_dict'), "features": fields.Raw(attribute="features_dict"),
'hash': fields.String(attribute='unique_hash'), "hash": fields.String(attribute="unique_hash"),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
'created_at': TimestampField, "created_at": TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
'updated_at': TimestampField, "updated_at": TimestampField,
'tool_published': fields.Boolean, "tool_published": fields.Boolean,
'environment_variables': fields.List(EnvironmentVariableField()), "environment_variables": fields.List(EnvironmentVariableField()),
'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
} }

View File

@ -13,7 +13,7 @@ workflow_run_for_log_fields = {
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField "finished_at": TimestampField,
} }
workflow_run_for_list_fields = { workflow_run_for_list_fields = {
@ -24,9 +24,9 @@ workflow_run_for_list_fields = {
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField "finished_at": TimestampField,
} }
advanced_chat_workflow_run_for_list_fields = { advanced_chat_workflow_run_for_list_fields = {
@ -39,40 +39,40 @@ advanced_chat_workflow_run_for_list_fields = {
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField "finished_at": TimestampField,
} }
advanced_chat_workflow_run_pagination_fields = { advanced_chat_workflow_run_pagination_fields = {
'limit': fields.Integer(attribute='limit'), "limit": fields.Integer(attribute="limit"),
'has_more': fields.Boolean(attribute='has_more'), "has_more": fields.Boolean(attribute="has_more"),
'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data') "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"),
} }
workflow_run_pagination_fields = { workflow_run_pagination_fields = {
'limit': fields.Integer(attribute='limit'), "limit": fields.Integer(attribute="limit"),
'has_more': fields.Boolean(attribute='has_more'), "has_more": fields.Boolean(attribute="has_more"),
'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"),
} }
workflow_run_detail_fields = { workflow_run_detail_fields = {
"id": fields.String, "id": fields.String,
"sequence_number": fields.Integer, "sequence_number": fields.Integer,
"version": fields.String, "version": fields.String,
"graph": fields.Raw(attribute='graph_dict'), "graph": fields.Raw(attribute="graph_dict"),
"inputs": fields.Raw(attribute='inputs_dict'), "inputs": fields.Raw(attribute="inputs_dict"),
"status": fields.String, "status": fields.String,
"outputs": fields.Raw(attribute='outputs_dict'), "outputs": fields.Raw(attribute="outputs_dict"),
"error": fields.String, "error": fields.String,
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_by_role": fields.String, "created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField "finished_at": TimestampField,
} }
workflow_run_node_execution_fields = { workflow_run_node_execution_fields = {
@ -82,21 +82,21 @@ workflow_run_node_execution_fields = {
"node_id": fields.String, "node_id": fields.String,
"node_type": fields.String, "node_type": fields.String,
"title": fields.String, "title": fields.String,
"inputs": fields.Raw(attribute='inputs_dict'), "inputs": fields.Raw(attribute="inputs_dict"),
"process_data": fields.Raw(attribute='process_data_dict'), "process_data": fields.Raw(attribute="process_data_dict"),
"outputs": fields.Raw(attribute='outputs_dict'), "outputs": fields.Raw(attribute="outputs_dict"),
"status": fields.String, "status": fields.String,
"error": fields.String, "error": fields.String,
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"execution_metadata": fields.Raw(attribute='execution_metadata_dict'), "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
"extras": fields.Raw, "extras": fields.Raw,
"created_at": TimestampField, "created_at": TimestampField,
"created_by_role": fields.String, "created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField "finished_at": TimestampField,
} }
workflow_run_node_execution_list_fields = { workflow_run_node_execution_list_fields = {
'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), "data": fields.List(fields.Nested(workflow_run_node_execution_fields)),
} }

View File

@ -69,7 +69,18 @@ ignore = [
] ]
[tool.ruff.format] [tool.ruff.format]
quote-style = "single" exclude = [
"core/**/*.py",
"controllers/**/*.py",
"models/**/*.py",
"utils/**/*.py",
"migrations/**/*",
"services/**/*.py",
"tasks/**/*.py",
"tests/**/*.py",
"libs/**/*.py",
"configs/**/*.py",
]
[tool.pytest_env] [tool.pytest_env]
OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"

View File

@ -11,27 +11,32 @@ from extensions.ext_database import db
from models.dataset import Embedding from models.dataset import Embedding
@app.celery.task(queue='dataset') @app.celery.task(queue="dataset")
def clean_embedding_cache_task(): def clean_embedding_cache_task():
click.echo(click.style('Start clean embedding cache.', fg='green')) click.echo(click.style("Start clean embedding cache.", fg="green"))
clean_days = int(dify_config.CLEAN_DAY_SETTING) clean_days = int(dify_config.CLEAN_DAY_SETTING)
start_at = time.perf_counter() start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
while True: while True:
try: try:
embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ embedding_ids = (
.order_by(Embedding.created_at.desc()).limit(100).all() db.session.query(Embedding.id)
.filter(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
)
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
except NotFound: except NotFound:
break break
if embedding_ids: if embedding_ids:
for embedding_id in embedding_ids: for embedding_id in embedding_ids:
db.session.execute(text( db.session.execute(
"DELETE FROM embeddings WHERE id = :embedding_id" text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id}
), {'embedding_id': embedding_id}) )
db.session.commit() db.session.commit()
else: else:
break break
end_at = time.perf_counter() end_at = time.perf_counter()
click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green"))

View File

@ -12,9 +12,9 @@ from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document from models.dataset import Dataset, DatasetQuery, Document
@app.celery.task(queue='dataset') @app.celery.task(queue="dataset")
def clean_unused_datasets_task(): def clean_unused_datasets_task():
click.echo(click.style('Start clean unused datasets indexes.', fg='green')) click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
clean_days = dify_config.CLEAN_DAY_SETTING clean_days = dify_config.CLEAN_DAY_SETTING
start_at = time.perf_counter() start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
@ -22,40 +22,44 @@ def clean_unused_datasets_task():
while True: while True:
try: try:
# Subquery for counting new documents # Subquery for counting new documents
document_subquery_new = db.session.query( document_subquery_new = (
Document.dataset_id, db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
func.count(Document.id).label('document_count') .filter(
).filter( Document.indexing_status == "completed",
Document.indexing_status == 'completed',
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
Document.updated_at > thirty_days_ago Document.updated_at > thirty_days_ago,
).group_by(Document.dataset_id).subquery() )
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents # Subquery for counting old documents
document_subquery_old = db.session.query( document_subquery_old = (
Document.dataset_id, db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
func.count(Document.id).label('document_count') .filter(
).filter( Document.indexing_status == "completed",
Document.indexing_status == 'completed',
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
Document.updated_at < thirty_days_ago Document.updated_at < thirty_days_ago,
).group_by(Document.dataset_id).subquery() )
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter # Main query with join and filter
datasets = (db.session.query(Dataset) datasets = (
.outerjoin( db.session.query(Dataset)
document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
).outerjoin( .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id .filter(
).filter(
Dataset.created_at < thirty_days_ago, Dataset.created_at < thirty_days_ago,
func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0 func.coalesce(document_subquery_old.c.document_count, 0) > 0,
).order_by( )
Dataset.created_at.desc() .order_by(Dataset.created_at.desc())
).paginate(page=page, per_page=50)) .paginate(page=page, per_page=50)
)
except NotFound: except NotFound:
break break
@ -63,10 +67,11 @@ def clean_unused_datasets_task():
break break
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
dataset_query = db.session.query(DatasetQuery).filter( dataset_query = (
DatasetQuery.created_at > thirty_days_ago, db.session.query(DatasetQuery)
DatasetQuery.dataset_id == dataset.id .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id)
).all() .all()
)
if not dataset_query or len(dataset_query) == 0: if not dataset_query or len(dataset_query) == 0:
try: try:
# remove index # remove index
@ -74,17 +79,14 @@ def clean_unused_datasets_task():
index_processor.clean(dataset, None) index_processor.clean(dataset, None)
# update document # update document
update_params = { update_params = {Document.enabled: False}
Document.enabled: False
}
Document.query.filter_by(dataset_id=dataset.id).update(update_params) Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit() db.session.commit()
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
fg='green'))
except Exception as e: except Exception as e:
click.echo( click.echo(
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
fg='red')) )
end_at = time.perf_counter() end_at = time.perf_counter()
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green"))

View File

@ -11,5 +11,8 @@ fi
# run ruff linter # run ruff linter
ruff check --fix ./api ruff check --fix ./api
# run ruff formatter
ruff format ./api
# run dotenv-linter linter # run dotenv-linter linter
dotenv-linter ./api/.env.example ./web/.env.example dotenv-linter ./api/.env.example ./web/.env.example