From 0de224b15345f0cdc46e75338bf5dcce15f266fe Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:09:04 +0800 Subject: [PATCH] fix wrong using of RetrievalMethod Enum (#6345) --- api/controllers/console/datasets/datasets.py | 16 ++++++++-------- api/core/rag/datasource/retrieval_service.py | 8 ++++---- api/core/rag/retrieval/dataset_retrieval.py | 4 ++-- api/core/rag/retrieval/retrival_methods.py | 6 +++--- .../dataset_multi_retriever_tool.py | 2 +- .../dataset_retriever/dataset_retriever_tool.py | 2 +- .../knowledge_retrieval_node.py | 2 +- api/models/dataset.py | 2 +- api/services/dataset_service.py | 4 ++-- api/services/hit_testing_service.py | 2 +- 10 files changed, 24 insertions(+), 24 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70c506bb0e..9166372df5 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -545,15 +545,15 @@ class DatasetRetrievalSettingApi(Resource): case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH + RetrievalMethod.SEMANTIC_SEARCH.value ] } case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH, - RetrievalMethod.FULL_TEXT_SEARCH, - RetrievalMethod.HYBRID_SEARCH, + RetrievalMethod.SEMANTIC_SEARCH.value, + RetrievalMethod.FULL_TEXT_SEARCH.value, + RetrievalMethod.HYBRID_SEARCH.value, ] } case _: @@ -569,15 +569,15 @@ class DatasetRetrievalSettingMockApi(Resource): case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH + RetrievalMethod.SEMANTIC_SEARCH.value ] } case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH, - RetrievalMethod.FULL_TEXT_SEARCH, - RetrievalMethod.HYBRID_SEARCH, + RetrievalMethod.SEMANTIC_SEARCH.value, + RetrievalMethod.FULL_TEXT_SEARCH.value, + RetrievalMethod.HYBRID_SEARCH.value, ] } case _: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 623b7a3123..8814c61433 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -11,7 +11,7 @@ from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -86,7 +86,7 @@ class RetrievalService: exception_message = ';\n'.join(exceptions) raise Exception(exception_message) - if retrival_method == RetrievalMethod.HYBRID_SEARCH: + if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( query=query, @@ -142,7 +142,7 @@ class RetrievalService: ) if documents: - if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH: + if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents.extend(data_post_processor.invoke( query=query, @@ -174,7 +174,7 @@ class RetrievalService: top_k=top_k ) if documents: - if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH: + if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents.extend(data_post_processor.invoke( query=query, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ea2a194a68..c1f5e0820c 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -28,7 +28,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -464,7 +464,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/rag/retrieval/retrival_methods.py b/api/core/rag/retrieval/retrival_methods.py index 9b7907013d..12aa28a51c 100644 --- a/api/core/rag/retrieval/retrival_methods.py +++ b/api/core/rag/retrieval/retrival_methods.py @@ -1,15 +1,15 @@ from enum import Enum -class RetrievalMethod(str, Enum): +class RetrievalMethod(Enum): SEMANTIC_SEARCH = 'semantic_search' FULL_TEXT_SEARCH = 'full_text_search' HYBRID_SEARCH = 'hybrid_search' @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH} + return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} @staticmethod def is_support_fulltext_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH} + return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 5b053678f3..eaf58ed5bd 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -14,7 +14,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index de2ce5858a..b1e541b8db 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -8,7 +8,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 9e29bd9ea1..12fe4dfa84 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -22,7 +22,7 @@ from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/models/dataset.py b/api/models/dataset.py index 02d49380bd..af840d26d6 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -117,7 +117,7 @@ class Dataset(db.Model): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3d9f1851b7..84049712d9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -688,7 +688,7 @@ class DocumentService: dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -1059,7 +1059,7 @@ class DocumentService: retrieval_model = document_data['retrieval_model'] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9bcf828712..b83e1d8cb7 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,7 +9,7 @@ from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '',