From 2fd56cb01c48f088724a6311367e8d65c658e11e Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 18 Dec 2023 21:33:54 +0800 Subject: [PATCH] Fix/vdb index issue (#1776) Co-authored-by: jyong --- api/controllers/console/app/annotation.py | 1 - api/controllers/console/wraps.py | 2 +- api/core/completion.py | 3 ++- api/core/index/vector_index/milvus_vector_index.py | 1 - api/core/index/vector_index/vector_index.py | 14 ++++++++++---- .../index/vector_index/weaviate_vector_index.py | 5 +++-- .../annotation/add_annotation_to_index_task.py | 2 ++ 7 files changed, 18 insertions(+), 10 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index a21a2eae64..af7f52e970 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -188,7 +188,6 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') def delete(self, app_id, annotation_id): # The role of the current user in the ta table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3230067756..19a5de69ed 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -65,7 +65,7 @@ def cloud_edition_billing_resource_check(resource: str, abort(403, error_msg) elif resource == 'workspace_custom' and not billing_info['can_replace_logo']: abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']: + elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']: abort(403, error_msg) else: return view(*args, **kwargs) diff --git a/api/core/completion.py b/api/core/completion.py index b25222ec7a..d219814db6 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -371,7 +371,8 @@ class Completion: vector_index = VectorIndex( dataset=dataset, config=current_app.config, - embeddings=embeddings + embeddings=embeddings, + attributes=['doc_id', 'annotation_id', 'app_id'] ) documents = vector_index.search( diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index 69fbc6beef..c175e5babb 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex): """Only for created index.""" if self._vector_store: return self._vector_store - attributes = ['doc_id', 'dataset_id', 'document_id'] return MilvusVectorStore( collection_name=self.get_index_name(self.dataset), diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py index dd3ab272e0..614f23a291 100644 --- a/api/core/index/vector_index/vector_index.py +++ b/api/core/index/vector_index/vector_index.py @@ -9,12 +9,17 @@ from models.dataset import Dataset, Document class VectorIndex: - def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings): + def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings, + attributes: list = None): + if attributes is None: + attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] self._dataset = dataset self._embeddings = embeddings - self._vector_index = self._init_vector_index(dataset, config, embeddings) + self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes) + self._attributes = attributes - def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex: + def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings, + attributes: list) -> BaseVectorIndex: vector_type = config.get('VECTOR_STORE') if self._dataset.index_struct_dict: @@ -33,7 +38,8 @@ class VectorIndex: api_key=config.get('WEAVIATE_API_KEY'), batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) ), - embeddings=embeddings + embeddings=embeddings, + attributes=attributes ) elif vector_type == "qdrant": from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index 0d51ab8d90..0ba8a20bca 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel): class WeaviateVectorIndex(BaseVectorIndex): - def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings): + def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list): super().__init__(dataset, embeddings) self._client = self._init_client(config) + self._attributes = attributes def _init_client(self, config: WeaviateConfig) -> weaviate.Client: auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) @@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex): if self._vector_store: return self._vector_store - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + attributes = self._attributes if self._is_origin(): attributes = ['doc_id'] diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 84d94f39ca..620413dffb 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -36,6 +36,8 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s id=app_id, tenant_id=tenant_id, indexing_technique='high_quality', + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id )