Add vector field for other vector databases (#7051)

This commit is contained in:
Jyong 2024-08-07 17:14:03 +08:00 committed by GitHub
parent aad02113c6
commit 80c94f02e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 23 additions and 14 deletions

View File

@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return: :return:
""" """
for message in self._queue_manager.listen(): for message in self._queue_manager.listen():
if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher: if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
publisher.publish(message=message) publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata') elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata and message.event.execution_metadata

View File

@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector):
documents = [] documents = []
for match in response.body.matches.match: for match in response.body.matches.match:
if match.score > score_threshold: if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
doc = Document( doc = Document(
page_content=match.metadata.get("page_content"), page_content=match.metadata.get("page_content"),
metadata=json.loads(match.metadata.get("metadata_")), vector=match.metadata.get("vector"),
metadata=metadata,
) )
documents.append(doc) documents.append(doc)
return documents return documents

View File

@ -126,13 +126,14 @@ class MyScaleVector(BaseVector):
where_str = f"WHERE dist < {1 - score_threshold}" if \ where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
sql = f""" sql = f"""
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k} {where_str} ORDER BY dist {order.value} LIMIT {top_k}
""" """
try: try:
return [ return [
Document( Document(
page_content=r["text"], page_content=r["text"],
vector=r['vector'],
metadata=r["metadata"], metadata=r["metadata"],
) )
for r in self._client.query(sql).named_results() for r in self._client.query(sql).named_results()

View File

@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector):
docs = [] docs = []
for hit in response['hits']['hits']: for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value) metadata = hit['_source'].get(Field.METADATA_KEY.value)
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc) docs.append(doc)
return docs return docs

View File

@ -234,16 +234,16 @@ class OracleVector(BaseVector):
entities.append(token) entities.append(token)
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute( cur.execute(
f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)] [" ACCUM ".join(entities)]
) )
docs = [] docs = []
for record in cur: for record in cur:
metadata, text = record metadata, text, embedding = record
docs.append(Document(page_content=text, metadata=metadata)) docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs return docs
else: else:
return [Document(page_content="", metadata="")] return [Document(page_content="", metadata={})]
return [] return []
def delete(self) -> None: def delete(self) -> None:

View File

@ -399,7 +399,6 @@ class QdrantVector(BaseVector):
document = self._document_from_scored_point( document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
) )
document.metadata['vector'] = result.vector
documents.append(document) documents.append(document)
return documents return documents
@ -418,6 +417,7 @@ class QdrantVector(BaseVector):
) -> Document: ) -> Document:
return Document( return Document(
page_content=scored_point.payload.get(content_payload_key), page_content=scored_point.payload.get(content_payload_key),
vector=scored_point.vector,
metadata=scored_point.payload.get(metadata_payload_key) or {}, metadata=scored_point.payload.get(metadata_payload_key) or {},
) )

View File

@ -239,8 +239,7 @@ class WeaviateVector(BaseVector):
query_obj = self._client.query.get(collection_name, properties) query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"): if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"): query_obj = query_obj.with_additional(["vector"])
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = ['text'] properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result: if "errors" in result:
@ -248,7 +247,8 @@ class WeaviateVector(BaseVector):
docs = [] docs = []
for res in result["data"]["Get"][collection_name]: for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value) text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res)) additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
return docs return docs
def _default_schema(self, index_name: str) -> dict: def _default_schema(self, index_name: str) -> dict:

View File

@ -10,6 +10,8 @@ class Document(BaseModel):
page_content: str page_content: str
vector: Optional[list[float]] = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other """Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.). documents, etc.).
""" """

View File

@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
top_k=self.top_k, top_k=self.top_k,
score_threshold=retrieval_model.get('score_threshold', .0) score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None, if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None), reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode') reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model', if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None), weights=retrieval_model.get('weights', None),

View File

@ -44,7 +44,8 @@ class HitTestingService:
top_k=retrieval_model.get('top_k', 2), top_k=retrieval_model.get('top_k', 2),
score_threshold=retrieval_model.get('score_threshold', .0) score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None, if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None), reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode') reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model', if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None), weights=retrieval_model.get('weights', None),