Add vector field for other vector databases (#7051)

This commit is contained in:
Jyong 2024-08-07 17:14:03 +08:00 committed by GitHub
parent aad02113c6
commit 80c94f02e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 23 additions and 14 deletions

View File

@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
for message in self._queue_manager.listen():
if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata

View File

@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector):
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=json.loads(match.metadata.get("metadata_")),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
return documents

View File

@ -126,13 +126,14 @@ class MyScaleVector(BaseVector):
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
sql = f"""
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
"""
try:
return [
Document(
page_content=r["text"],
vector=r['vector'],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()

View File

@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs

View File

@ -234,16 +234,16 @@ class OracleVector(BaseVector):
entities.append(token)
with self._get_cursor() as cur:
cur.execute(
f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)]
)
docs = []
for record in cur:
metadata, text = record
docs.append(Document(page_content=text, metadata=metadata))
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs
else:
return [Document(page_content="", metadata="")]
return [Document(page_content="", metadata={})]
return []
def delete(self) -> None:

View File

@ -399,7 +399,6 @@ class QdrantVector(BaseVector):
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
)
document.metadata['vector'] = result.vector
documents.append(document)
return documents
@ -418,6 +417,7 @@ class QdrantVector(BaseVector):
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
vector=scored_point.vector,
metadata=scored_point.payload.get(metadata_payload_key) or {},
)

View File

@ -239,8 +239,7 @@ class WeaviateVector(BaseVector):
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result:
@ -248,7 +247,8 @@ class WeaviateVector(BaseVector):
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res))
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:

View File

@ -10,6 +10,8 @@ class Document(BaseModel):
page_content: str
vector: Optional[list[float]] = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""

View File

@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
top_k=self.top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None),
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),

View File

@ -44,7 +44,8 @@ class HitTestingService:
top_k=retrieval_model.get('top_k', 2),
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None),
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),