refactor: extract cors configs into dify config and cleanup the config class (#5507)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Bowen Liang 2024-06-25 15:48:02 +08:00 committed by GitHub
parent ec1d3ddee2
commit 2a0f03a511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 51 additions and 55 deletions

View File

@ -77,7 +77,6 @@ jobs:
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
docker/docker-compose.chroma.yaml
docker/docker-compose.oracle.yaml
services: |
weaviate
qdrant
@ -87,7 +86,6 @@ jobs:
pgvecto-rs
pgvector
chroma
oracle
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@ -24,7 +24,6 @@ from flask_cors import CORS
from werkzeug.exceptions import Unauthorized
from commands import register_commands
from config import Config
# DO NOT REMOVE BELOW
from events import event_handlers
@ -82,7 +81,6 @@ def create_flask_app_with_configs() -> Flask:
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_object(Config())
dify_app.config.from_mapping(DifyConfig().model_dump())
return dify_app
@ -232,7 +230,7 @@ def register_blueprints(app):
app = create_app()
celery = app.extensions["celery"]
if app.config['TESTING']:
if app.config.get('TESTING'):
print("App is running in TESTING mode")

View File

@ -1,42 +0,0 @@
import os
import dotenv
DEFAULTS = {
}
def get_env(key):
return os.environ.get(key, DEFAULTS.get(key))
def get_bool_env(key):
value = get_env(key)
return value.lower() == 'true' if value is not None else False
def get_cors_allow_origins(env, default):
cors_allow_origins = []
if get_env(env):
for origin in get_env(env).split(','):
cors_allow_origins.append(origin)
else:
cors_allow_origins = [default]
return cors_allow_origins
class Config:
"""Application configuration class."""
def __init__(self):
dotenv.load_dotenv()
self.TESTING = False
self.APPLICATION_NAME = "langgenius/dify"
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', get_env('CONSOLE_WEB_URL'))
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'WEB_API_CORS_ALLOW_ORIGINS', '*')

View File

@ -5,6 +5,16 @@ class DeploymentConfig(BaseModel):
"""
Deployment configs
"""
APPLICATION_NAME: str = Field(
description='application name',
default='langgenius/dify',
)
TESTING: bool = Field(
description='',
default=False,
)
EDITION: str = Field(
description='deployment edition',
default='SELF_HOSTED',

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
from configs.feature.hosted_service import HostedServiceConfig
@ -125,6 +125,28 @@ class HttpConfig(BaseModel):
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
default='',
)
@computed_field
@property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
inner_WEB_API_CORS_ALLOW_ORIGINS: Optional[str] = Field(
description='',
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default='*',
)
@computed_field
@property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
class InnerAPIConfig(BaseModel):
"""

View File

@ -1,4 +1,5 @@
import logging
import os
import time
from enum import Enum
from threading import Lock
@ -8,7 +9,6 @@ from httpx import get, post
from pydantic import BaseModel
from yarl import URL
from config import get_env
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
# Code Executor
CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
CODE_EXECUTION_TIMEOUT= (10, 60)

View File

@ -15,6 +15,7 @@ def example_env_file(tmp_path, monkeypatch) -> str:
file_path.write_text(dedent(
"""
CONSOLE_API_URL=https://example.com
CONSOLE_WEB_URL=https://example.com
"""))
return str(file_path)
@ -47,14 +48,13 @@ def test_flask_configs(example_env_file):
flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump())
config = flask_app.config
# configs read from dotenv directly
assert config['LOG_LEVEL'] == 'INFO'
# configs read from pydantic-settings
assert config['LOG_LEVEL'] == 'INFO'
assert config['COMMIT_SHA'] == ''
assert config['EDITION'] == 'SELF_HOSTED'
assert config['API_COMPRESSION_ENABLED'] is False
assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0
assert config['TESTING'] == False
# value from env file
assert config['CONSOLE_API_URL'] == 'https://example.com'
@ -71,3 +71,7 @@ def test_flask_configs(example_env_file):
'pool_recycle': 3600,
'pool_size': 30,
}
assert config['CONSOLE_WEB_URL']=='https://example.com'
assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com']
assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*']

View File

@ -1,4 +1,10 @@
#!/bin/bash
set -x
pytest api/tests/integration_tests/vdb/
pytest api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/milvus \
api/tests/integration_tests/vdb/pgvecto_rs \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate \
api/tests/integration_tests/vdb/test_vector_store.py