diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a042d30e00..7abcde25b6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :return: """ for message in self._queue_manager.listen(): - if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher: + if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher: publisher.publish(message=message) elif (hasattr(message.event, 'execution_metadata') and message.event.execution_metadata diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 442d71293f..7d41512d29 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector): documents = [] for match in response.body.matches.match: if match.score > score_threshold: + metadata = json.loads(match.metadata.get("metadata_")) doc = Document( page_content=match.metadata.get("page_content"), - metadata=json.loads(match.metadata.get("metadata_")), + vector=match.metadata.get("vector"), + metadata=metadata, ) documents.append(doc) return documents diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 241b5a8414..cff9293baa 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -126,13 +126,14 @@ class MyScaleVector(BaseVector): where_str = f"WHERE dist < {1 - score_threshold}" if \ self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" sql = f""" - SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} + SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} """ try: return [ Document( page_content=r["text"], + vector=r['vector'], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index d834e8ce14..c95d202173 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector): docs = [] for hit in response['hits']['hits']: metadata = hit['_source'].get(Field.METADATA_KEY.value) - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + vector = hit['_source'].get(Field.VECTOR.value) + page_content = hit['_source'].get(Field.CONTENT_KEY.value) + doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 4bd09b331d..aa2c6171c3 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -234,16 +234,16 @@ class OracleVector(BaseVector): entities.append(token) with self._get_cursor() as cur: cur.execute( - f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", + f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", [" ACCUM ".join(entities)] ) docs = [] for record in cur: - metadata, text = record - docs.append(Document(page_content=text, metadata=metadata)) + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) return docs else: - return [Document(page_content="", metadata="")] + return [Document(page_content="", metadata={})] return [] def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 77c3f6a271..297bff928e 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -399,7 +399,6 @@ class QdrantVector(BaseVector): document = self._document_from_scored_point( result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value ) - document.metadata['vector'] = result.vector documents.append(document) return documents @@ -418,6 +417,7 @@ class QdrantVector(BaseVector): ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), + vector=scored_point.vector, metadata=scored_point.payload.get(metadata_payload_key) or {}, ) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 87fc5ff158..205fe850c3 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -239,8 +239,7 @@ class WeaviateVector(BaseVector): query_obj = self._client.query.get(collection_name, properties) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) + query_obj = query_obj.with_additional(["vector"]) properties = ['text'] result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() if "errors" in result: @@ -248,7 +247,8 @@ class WeaviateVector(BaseVector): docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - docs.append(Document(page_content=text, metadata=res)) + additional = res.pop('_additional') + docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 7bb675b149..6f3c1c5d34 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -10,6 +10,8 @@ class Document(BaseModel): page_content: str + vector: Optional[list[float]] = None + """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index de8ff7ad38..a7e70af628 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): top_k=self.top_k, score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None), + reranking_model=retrieval_model.get('reranking_model', None) + if retrieval_model['reranking_enable'] else None, reranking_mode=retrieval_model.get('reranking_mode') if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 0e072a3e21..de5f6994b0 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -44,7 +44,8 @@ class HitTestingService: top_k=retrieval_model.get('top_k', 2), score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None), + reranking_model=retrieval_model.get('reranking_model', None) + if retrieval_model['reranking_enable'] else None, reranking_mode=retrieval_model.get('reranking_mode') if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None),