From 13fcd7a901b78219fb7ecae5aea250ca70ec6ddc Mon Sep 17 00:00:00 2001 From: Xiaoguang Sun Date: Mon, 24 Jun 2024 14:41:07 +0800 Subject: [PATCH] feat: Add program_name attribute to TiDB connection (#5499) Signed-off-by: Xiaoguang Sun --- api/config.py | 1 + api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py | 6 +++++- .../integration_tests/vdb/tidb_vector/test_tidb_vector.py | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/api/config.py b/api/config.py index 35e8ab5e94..38a8ca31d4 100644 --- a/api/config.py +++ b/api/config.py @@ -33,6 +33,7 @@ class Config: dotenv.load_dotenv() self.TESTING = False + self.APPLICATION_NAME = "langgenius/dify" # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 3a9a56f93a..1da0fd554f 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -26,6 +26,7 @@ class TiDBVectorConfig(BaseModel): user: str password: str database: str + program_name: str @model_validator(mode='before') def validate_config(cls, values: dict) -> dict: @@ -39,6 +40,8 @@ class TiDBVectorConfig(BaseModel): raise ValueError("config TIDB_VECTOR_PASSWORD is required") if not values['database']: raise ValueError("config TIDB_VECTOR_DATABASE is required") + if not values['program_name']: + raise ValueError("config APPLICATION_NAME is required") return values @@ -65,7 +68,7 @@ class TiDBVector(BaseVector): super().__init__(collection_name) self._client_config = config self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true") + f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -245,5 +248,6 @@ class TiDBVectorFactory(AbstractVectorFactory): user=config.get('TIDB_VECTOR_USER'), password=config.get('TIDB_VECTOR_PASSWORD'), database=config.get('TIDB_VECTOR_DATABASE'), + program_name=config.get('APPLICATION_NAME'), ), ) \ No newline at end of file diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 837a228a55..7cd8d22e91 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -16,7 +16,8 @@ def tidb_vector(): port="4000", user="xxx.root", password="xxxxxx", - database="dify" + database="dify", + program_name="langgenius/dify" ) )