diff --git a/api/app.py b/api/app.py index a3efabf06c..ed214bde97 100644 --- a/api/app.py +++ b/api/app.py @@ -1,5 +1,7 @@ import os +from configs import dify_config + if os.environ.get("DEBUG", "false").lower() != "true": from gevent import monkey @@ -36,17 +38,11 @@ if hasattr(time, "tzset"): time.tzset() -# ------------- -# Configuration -# ------------- -config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first - - # create app app = create_app() celery = app.extensions["celery"] -if app.config.get("TESTING"): +if dify_config.TESTING: print("App is running in TESTING mode") @@ -54,15 +50,15 @@ if app.config.get("TESTING"): def after_request(response): """Add Version headers to the response.""" response.set_cookie("remember_token", "", expires=0) - response.headers.add("X-Version", app.config["CURRENT_VERSION"]) - response.headers.add("X-Env", app.config["DEPLOY_ENV"]) + response.headers.add("X-Version", dify_config.CURRENT_VERSION) + response.headers.add("X-Env", dify_config.DEPLOY_ENV) return response @app.route("/health") def health(): return Response( - json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), status=200, content_type="application/json", ) diff --git a/api/app_factory.py b/api/app_factory.py index b7bfe947f5..aba78ccab8 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -68,7 +68,7 @@ def create_flask_app_with_configs() -> Flask: def create_app() -> Flask: app = create_flask_app_with_configs() - app.secret_key = app.config["SECRET_KEY"] + app.secret_key = dify_config.SECRET_KEY initialize_extensions(app) register_blueprints(app) register_commands(app) @@ -150,7 +150,7 @@ def register_blueprints(app): CORS( web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, allow_headers=["Content-Type", "Authorization", "X-App-Code"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], @@ -161,7 +161,7 @@ def register_blueprints(app): CORS( console_app_bp, - resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, allow_headers=["Content-Type", "Authorization"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 307bc94a79..8f87356934 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -32,6 +32,21 @@ class SecurityConfig(BaseSettings): default=5, ) + LOGIN_DISABLED: bool = Field( + description="Whether to disable login checks", + default=False, + ) + + ADMIN_API_KEY_ENABLE: bool = Field( + description="Whether to enable admin api key for authentication", + default=False, + ) + + ADMIN_API_KEY: Optional[str] = Field( + description="admin api key for authentication", + default=None, + ) + class AppExecutionConfig(BaseSettings): """ diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index f78ea9b288..a70c4a31c7 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,10 +1,10 @@ -import os from functools import wraps from flask import request from flask_restful import Resource, reqparse from werkzeug.exceptions import NotFound, Unauthorized +from configs import dify_config from constants.languages import supported_language from controllers.console import api from controllers.console.wraps import only_edition_cloud @@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv("ADMIN_API_KEY"): + if not dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") auth_header = request.headers.get("Authorization") @@ -31,7 +31,7 @@ def admin_required(view): if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if os.getenv("ADMIN_API_KEY") != auth_token: + if dify_config.ADMIN_API_KEY != auth_token: raise Unauthorized("API key is invalid.") return view(*args, **kwargs) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index eeeccc2349..b47ba67f2f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,8 +1,9 @@ from typing import Optional -from flask import Config, Flask +from flask import Flask from pydantic import BaseModel +from configs import dify_config from core.entities.provider_entities import QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType @@ -44,32 +45,30 @@ class HostingConfiguration: moderation_config: HostedModerationConfig = None def init_app(self, app: Flask) -> None: - config = app.config - - if config.get("EDITION") != "CLOUD": + if dify_config.EDITION != "CLOUD": return - self.provider_map["azure_openai"] = self.init_azure_openai(config) - self.provider_map["openai"] = self.init_openai(config) - self.provider_map["anthropic"] = self.init_anthropic(config) - self.provider_map["minimax"] = self.init_minimax(config) - self.provider_map["spark"] = self.init_spark(config) - self.provider_map["zhipuai"] = self.init_zhipuai(config) + self.provider_map["azure_openai"] = self.init_azure_openai() + self.provider_map["openai"] = self.init_openai() + self.provider_map["anthropic"] = self.init_anthropic() + self.provider_map["minimax"] = self.init_minimax() + self.provider_map["spark"] = self.init_spark() + self.provider_map["zhipuai"] = self.init_zhipuai() - self.moderation_config = self.init_moderation_config(config) + self.moderation_config = self.init_moderation_config() @staticmethod - def init_azure_openai(app_config: Config) -> HostingProvider: + def init_azure_openai() -> HostingProvider: quota_unit = QuotaUnit.TIMES - if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): + if dify_config.HOSTED_AZURE_OPENAI_ENABLED: credentials = { - "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), - "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), + "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY, + "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE, "base_model_name": "gpt-35-turbo", } quotas = [] - hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) + hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, restrict_models=[ @@ -122,31 +121,31 @@ class HostingConfiguration: quota_unit=quota_unit, ) - def init_openai(self, app_config: Config) -> HostingProvider: + def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS quotas = [] - if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) - trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") + if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: + hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT + trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) - if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): - paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") + if dify_config.HOSTED_OPENAI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS") paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: credentials = { - "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"), + "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY, } - if app_config.get("HOSTED_OPENAI_API_BASE"): - credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE") + if dify_config.HOSTED_OPENAI_API_BASE: + credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE - if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): - credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") + if dify_config.HOSTED_OPENAI_API_ORGANIZATION: + credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) @@ -156,26 +155,26 @@ class HostingConfiguration: ) @staticmethod - def init_anthropic(app_config: Config) -> HostingProvider: + def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas = [] - if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) + if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: + hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) - if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): + if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: paid_quota = PaidHostingQuota() quotas.append(paid_quota) if len(quotas) > 0: credentials = { - "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"), + "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY, } - if app_config.get("HOSTED_ANTHROPIC_API_BASE"): - credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") + if dify_config.HOSTED_ANTHROPIC_API_BASE: + credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) @@ -185,9 +184,9 @@ class HostingConfiguration: ) @staticmethod - def init_minimax(app_config: Config) -> HostingProvider: + def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_MINIMAX_ENABLED"): + if dify_config.HOSTED_MINIMAX_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( @@ -203,9 +202,9 @@ class HostingConfiguration: ) @staticmethod - def init_spark(app_config: Config) -> HostingProvider: + def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_SPARK_ENABLED"): + if dify_config.HOSTED_SPARK_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( @@ -221,9 +220,9 @@ class HostingConfiguration: ) @staticmethod - def init_zhipuai(app_config: Config) -> HostingProvider: + def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_ZHIPUAI_ENABLED"): + if dify_config.HOSTED_ZHIPUAI_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( @@ -239,17 +238,15 @@ class HostingConfiguration: ) @staticmethod - def init_moderation_config(app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): - return HostedModerationConfig( - enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") - ) + def init_moderation_config() -> HostedModerationConfig: + if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: + return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) return HostedModerationConfig(enabled=False) @staticmethod - def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: - models_str = app_config.get(env_var) + def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]: + models_str = dify_config.model_dump().get(env_var) models_list = models_str.split(",") if models_str else [] return [ RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 69d2aa4f76..3811458e02 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -428,14 +428,13 @@ class QdrantVectorFactory(AbstractVectorFactory): if not dataset.index_struct_dict: dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) - config = current_app.config return QdrantVector( collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, - root_path=config.root_path, + root_path=current_app.config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 0ff9f90847..b9b019373d 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -3,6 +3,8 @@ from datetime import timedelta from celery import Celery, Task from flask import Flask +from configs import dify_config + def init_app(app: Flask) -> Celery: class FlaskTask(Task): @@ -12,19 +14,19 @@ def init_app(app: Flask) -> Celery: broker_transport_options = {} - if app.config.get("CELERY_USE_SENTINEL"): + if dify_config.CELERY_USE_SENTINEL: broker_transport_options = { - "master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"), + "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, "sentinel_kwargs": { - "socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1), + "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, }, } celery_app = Celery( app.name, task_cls=FlaskTask, - broker=app.config.get("CELERY_BROKER_URL"), - backend=app.config.get("CELERY_BACKEND"), + broker=dify_config.CELERY_BROKER_URL, + backend=dify_config.CELERY_BACKEND, task_ignore_result=True, ) @@ -37,12 +39,12 @@ def init_app(app: Flask) -> Celery: } celery_app.conf.update( - result_backend=app.config.get("CELERY_RESULT_BACKEND"), + result_backend=dify_config.CELERY_RESULT_BACKEND, broker_transport_options=broker_transport_options, broker_connection_retry_on_startup=True, ) - if app.config.get("BROKER_USE_SSL"): + if dify_config.BROKER_USE_SSL: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) @@ -54,7 +56,7 @@ def init_app(app: Flask) -> Celery: "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", ] - day = app.config.get("CELERY_BEAT_SCHEDULER_TIME") + day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { "clean_embedding_cache_task": { "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 38e67749fc..a6de28597b 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -1,8 +1,10 @@ from flask import Flask +from configs import dify_config + def init_app(app: Flask): - if app.config.get("API_COMPRESSION_ENABLED"): + if dify_config.API_COMPRESSION_ENABLED: from flask_compress import Compress app.config["COMPRESS_MIMETYPES"] = [ diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 22f868881f..9e1a241b67 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -5,10 +5,12 @@ from logging.handlers import RotatingFileHandler from flask import Flask +from configs import dify_config + def init_app(app: Flask): log_handlers = None - log_file = app.config.get("LOG_FILE") + log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) @@ -22,13 +24,13 @@ def init_app(app: Flask): ] logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config.get("LOG_FORMAT"), - datefmt=app.config.get("LOG_DATEFORMAT"), + level=dify_config.LOG_LEVEL, + format=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, handlers=log_handlers, force=True, ) - log_tz = app.config.get("LOG_TZ") + log_tz = dify_config.LOG_TZ if log_tz: from datetime import datetime diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index b435294abc..5c5b331d8a 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -4,6 +4,8 @@ from typing import Optional import resend from flask import Flask +from configs import dify_config + class Mail: def __init__(self): @@ -14,41 +16,44 @@ class Mail: return self._client is not None def init_app(self, app: Flask): - if app.config.get("MAIL_TYPE"): - if app.config.get("MAIL_DEFAULT_SEND_FROM"): - self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") + mail_type = dify_config.MAIL_TYPE + if not mail_type: + logging.warning("MAIL_TYPE is not set") + return - if app.config.get("MAIL_TYPE") == "resend": - api_key = app.config.get("RESEND_API_KEY") + if dify_config.MAIL_DEFAULT_SEND_FROM: + self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM + + match mail_type: + case "resend": + api_key = dify_config.RESEND_API_KEY if not api_key: raise ValueError("RESEND_API_KEY is not set") - api_url = app.config.get("RESEND_API_URL") + api_url = dify_config.RESEND_API_URL if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get("MAIL_TYPE") == "smtp": + case "smtp": from libs.smtp import SMTPClient - if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT: raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") - if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS: raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get("SMTP_SERVER"), - port=app.config.get("SMTP_PORT"), - username=app.config.get("SMTP_USERNAME"), - password=app.config.get("SMTP_PASSWORD"), - _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), - use_tls=app.config.get("SMTP_USE_TLS"), - opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), + server=dify_config.SMTP_SERVER, + port=dify_config.SMTP_PORT, + username=dify_config.SMTP_USERNAME, + password=dify_config.SMTP_PASSWORD, + _from=dify_config.MAIL_DEFAULT_SEND_FROM, + use_tls=dify_config.SMTP_USE_TLS, + opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) - else: - raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) - else: - logging.warning("MAIL_TYPE is not set") + case _: + raise ValueError("Unsupported mail type {}".format(mail_type)) def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 054769e7ff..e1f8409f21 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -2,6 +2,8 @@ import redis from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel +from configs import dify_config + class RedisClientWrapper(redis.Redis): """ @@ -43,37 +45,37 @@ redis_client = RedisClientWrapper() def init_app(app): global redis_client connection_class = Connection - if app.config.get("REDIS_USE_SSL"): + if dify_config.REDIS_USE_SSL: connection_class = SSLConnection redis_params = { - "username": app.config.get("REDIS_USERNAME"), - "password": app.config.get("REDIS_PASSWORD"), - "db": app.config.get("REDIS_DB"), + "username": dify_config.REDIS_USERNAME, + "password": dify_config.REDIS_PASSWORD, + "db": dify_config.REDIS_DB, "encoding": "utf-8", "encoding_errors": "strict", "decode_responses": False, } - if app.config.get("REDIS_USE_SENTINEL"): + if dify_config.REDIS_USE_SENTINEL: sentinel_hosts = [ - (node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",") + (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") ] sentinel = Sentinel( sentinel_hosts, sentinel_kwargs={ - "socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1), - "username": app.config.get("REDIS_SENTINEL_USERNAME"), - "password": app.config.get("REDIS_SENTINEL_PASSWORD"), + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, }, ) - master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params) + master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) redis_client.initialize(master) else: redis_params.update( { - "host": app.config.get("REDIS_HOST"), - "port": app.config.get("REDIS_PORT"), + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, "connection_class": connection_class, } ) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index e255e7eb35..11f1dd93c6 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,6 +5,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException +from configs import dify_config from core.model_runtime.errors.invoke import InvokeRateLimitError @@ -18,9 +19,9 @@ def before_send(event, hint): def init_app(app): - if app.config.get("SENTRY_DSN"): + if dify_config.SENTRY_DSN: sentry_sdk.init( - dsn=app.config.get("SENTRY_DSN"), + dsn=dify_config.SENTRY_DSN, integrations=[FlaskIntegration(), CeleryIntegration()], ignore_errors=[ HTTPException, @@ -29,9 +30,9 @@ def init_app(app): InvokeRateLimitError, parse_error.defaultErrorResponse, ], - traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), - profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), - environment=app.config.get("DEPLOY_ENV"), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", + traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, + profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, + environment=dify_config.DEPLOY_ENV, + release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}", before_send=before_send, ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 5fc4f88832..50c5d7aebc 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -15,7 +15,7 @@ class Storage: def init_app(self, app: Flask): storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) - self.storage_runner = storage_factory(app=app) + self.storage_runner = storage_factory() @staticmethod def get_storage_factory(storage_type: str) -> type[BaseStorage]: diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index ae6911e945..01c1000e50 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,29 +1,27 @@ from collections.abc import Generator import oss2 as aliyun_s3 -from flask import Flask +from configs import dify_config from extensions.storage.base_storage import BaseStorage class AliyunOssStorage(BaseStorage): """Implementation for Aliyun OSS storage.""" - def __init__(self, app: Flask): - super().__init__(app) - - app_config = self.app.config - self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") - self.folder = app.config.get("ALIYUN_OSS_PATH") + def __init__(self): + super().__init__() + self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME + self.folder = dify_config.ALIYUN_OSS_PATH oss_auth_method = aliyun_s3.Auth region = None - if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": + if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4": oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get("ALIYUN_OSS_REGION") - oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) + region = dify_config.ALIYUN_OSS_REGION + oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY) self.client = aliyun_s3.Bucket( oss_auth, - app_config.get("ALIYUN_OSS_ENDPOINT"), + dify_config.ALIYUN_OSS_ENDPOINT, self.bucket_name, connect_timeout=30, region=region, diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index daea660a49..477507feda 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -2,8 +2,8 @@ from collections.abc import Generator from datetime import datetime, timedelta, timezone from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas -from flask import Flask +from configs import dify_config from extensions.ext_redis import redis_client from extensions.storage.base_storage import BaseStorage @@ -11,13 +11,12 @@ from extensions.storage.base_storage import BaseStorage class AzureBlobStorage(BaseStorage): """Implementation for Azure Blob storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") - self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") - self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") - self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") + def __init__(self): + super().__init__() + self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME + self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL + self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME + self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY def save(self, filename, data): client = self._sync_client() diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index c5acff4a9d..cd69439749 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -5,24 +5,23 @@ from collections.abc import Generator from baidubce.auth.bce_credentials import BceCredentials from baidubce.bce_client_configuration import BceClientConfiguration from baidubce.services.bos.bos_client import BosClient -from flask import Flask +from configs import dify_config from extensions.storage.base_storage import BaseStorage class BaiduObsStorage(BaseStorage): """Implementation for Baidu OBS storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("BAIDU_OBS_BUCKET_NAME") + def __init__(self): + super().__init__() + self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME client_config = BceClientConfiguration( credentials=BceCredentials( - access_key_id=app_config.get("BAIDU_OBS_ACCESS_KEY"), - secret_access_key=app_config.get("BAIDU_OBS_SECRET_KEY"), + access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY, + secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY, ), - endpoint=app_config.get("BAIDU_OBS_ENDPOINT"), + endpoint=dify_config.BAIDU_OBS_ENDPOINT, ) self.client = BosClient(config=client_config) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index c3fe9ec82a..50abab8537 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -3,16 +3,12 @@ from abc import ABC, abstractmethod from collections.abc import Generator -from flask import Flask - class BaseStorage(ABC): """Interface for file storage.""" - app = None - - def __init__(self, app: Flask): - self.app = app + def __init__(self): # noqa: B027 + pass @abstractmethod def save(self, filename, data): diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 2d1224fd74..e90392a6ba 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,20 +3,20 @@ import io import json from collections.abc import Generator -from flask import Flask from google.cloud import storage as google_cloud_storage +from configs import dify_config from extensions.storage.base_storage import BaseStorage class GoogleCloudStorage(BaseStorage): """Implementation for Google Cloud storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") - service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME + service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 # if service_account_json_str is empty, use Application Default Credentials if service_account_json_str: service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index dd243d4001..3c443d87ac 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,22 +1,22 @@ from collections.abc import Generator -from flask import Flask from obs import ObsClient +from configs import dify_config from extensions.storage.base_storage import BaseStorage class HuaweiObsStorage(BaseStorage): """Implementation for Huawei OBS storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME") + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME self.client = ObsClient( - access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"), - secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"), - server=app_config.get("HUAWEI_OBS_SERVER"), + access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY, + secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY, + server=dify_config.HUAWEI_OBS_SERVER, ) def save(self, filename, data): diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py index 9308c4d180..e458b3ce8a 100644 --- a/api/extensions/storage/local_fs_storage.py +++ b/api/extensions/storage/local_fs_storage.py @@ -3,19 +3,20 @@ import shutil from collections.abc import Generator from pathlib import Path -from flask import Flask +from flask import current_app +from configs import dify_config from extensions.storage.base_storage import BaseStorage class LocalFsStorage(BaseStorage): """Implementation for local filesystem storage.""" - def __init__(self, app: Flask): - super().__init__(app) - folder = self.app.config.get("STORAGE_LOCAL_PATH") + def __init__(self): + super().__init__() + folder = dify_config.STORAGE_LOCAL_PATH if not os.path.isabs(folder): - folder = os.path.join(app.root_path, folder) + folder = os.path.join(current_app.root_path, folder) self.folder = folder def save(self, filename, data): diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index 5295dbdca2..e4f50b34e9 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -2,24 +2,24 @@ from collections.abc import Generator import boto3 from botocore.exceptions import ClientError -from flask import Flask +from configs import dify_config from extensions.storage.base_storage import BaseStorage class OracleOCIStorage(BaseStorage): """Implementation for Oracle OCI storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("OCI_BUCKET_NAME") + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.OCI_BUCKET_NAME self.client = boto3.client( "s3", - aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), - aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), - endpoint_url=app_config.get("OCI_ENDPOINT"), - region_name=app_config.get("OCI_REGION"), + aws_secret_access_key=dify_config.OCI_SECRET_KEY, + aws_access_key_id=dify_config.OCI_ACCESS_KEY, + endpoint_url=dify_config.OCI_ENDPOINT, + region_name=dify_config.OCI_REGION, ) def save(self, filename, data): diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index c529dce7ad..8fd8e703a1 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,23 +1,23 @@ from collections.abc import Generator -from flask import Flask from qcloud_cos import CosConfig, CosS3Client +from configs import dify_config from extensions.storage.base_storage import BaseStorage class TencentCosStorage(BaseStorage): """Implementation for Tencent Cloud COS storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME config = CosConfig( - Region=app_config.get("TENCENT_COS_REGION"), - SecretId=app_config.get("TENCENT_COS_SECRET_ID"), - SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), - Scheme=app_config.get("TENCENT_COS_SCHEME"), + Region=dify_config.TENCENT_COS_REGION, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, ) self.client = CosS3Client(config) diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 1bedcf24c2..389c5630e3 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,23 +1,22 @@ from collections.abc import Generator import tos -from flask import Flask +from configs import dify_config from extensions.storage.base_storage import BaseStorage class VolcengineTosStorage(BaseStorage): """Implementation for Volcengine TOS storage.""" - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME") + def __init__(self): + super().__init__() + self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME self.client = tos.TosClientV2( - ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"), - sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"), - endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"), - region=app_config.get("VOLCENGINE_TOS_REGION"), + ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, + sk=dify_config.VOLCENGINE_TOS_SECRET_KEY, + endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT, + region=dify_config.VOLCENGINE_TOS_REGION, ) def save(self, filename, data): diff --git a/api/libs/helper.py b/api/libs/helper.py index e674d7e84b..7638796508 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -12,9 +12,10 @@ from hashlib import sha256 from typing import Any, Optional, Union from zoneinfo import available_timezones -from flask import Response, current_app, stream_with_context +from flask import Response, stream_with_context from flask_restful import fields +from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.file import helpers as file_helpers from extensions.ext_redis import redis_client @@ -214,7 +215,7 @@ class TokenManager: if additional_data: token_data.update(additional_data) - expiry_minutes = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"] + expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") token_key = cls._get_token_key(token, token_type) expiry_time = int(expiry_minutes * 60) redis_client.setex(token_key, expiry_time, json.dumps(token_data)) diff --git a/api/libs/login.py b/api/libs/login.py index 7f05eb8404..0ea191a185 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,4 +1,3 @@ -import os from functools import wraps from flask import current_app, g, has_request_context, request @@ -7,6 +6,7 @@ from flask_login.config import EXEMPT_METHODS from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy +from configs import dify_config from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin @@ -52,8 +52,7 @@ def login_required(func): @wraps(func) def decorated_view(*args, **kwargs): auth_header = request.headers.get("Authorization") - admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") - if admin_api_key_enable.lower() == "true": + if dify_config.ADMIN_API_KEY_ENABLE: if auth_header: if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") @@ -61,10 +60,10 @@ def login_required(func): auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - admin_api_key = os.getenv("ADMIN_API_KEY") + admin_api_key = dify_config.ADMIN_API_KEY if admin_api_key: - if os.getenv("ADMIN_API_KEY") == auth_token: + if admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: tenant_account_join = ( @@ -82,7 +81,7 @@ def login_required(func): account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) user_logged_in.send(current_app._get_current_object(), user=_get_user()) - if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): + if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: return current_app.login_manager.unauthorized() diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py index 93065ee95c..32e8c11d19 100644 --- a/api/tests/integration_tests/controllers/app_fixture.py +++ b/api/tests/integration_tests/controllers/app_fixture.py @@ -1,6 +1,7 @@ import pytest from app_factory import create_app +from configs import dify_config mock_user = type( "MockUser", @@ -20,5 +21,5 @@ mock_user = type( @pytest.fixture def app(): app = create_app() - app.config["LOGIN_DISABLED"] = True + dify_config.LOGIN_DISABLED = True return app diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 3f334a3764..545d18044d 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -25,7 +25,7 @@ class VolcengineTosTest: return cls._instance def __init__(self): - self.storage = VolcengineTosStorage(app=Flask(__name__)) + self.storage = VolcengineTosStorage() self.storage.bucket_name = get_example_bucket() self.storage.client = TosClientV2( ak="dify",