mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
add vector field for other vectordb (#7051)
This commit is contained in:
parent
aad02113c6
commit
80c94f02e9
|
@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
|
||||
if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
|
||||
publisher.publish(message=message)
|
||||
elif (hasattr(message.event, 'execution_metadata')
|
||||
and message.event.execution_metadata
|
||||
|
|
|
@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector):
|
|||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=json.loads(match.metadata.get("metadata_")),
|
||||
vector=match.metadata.get("vector"),
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
|
|
@ -126,13 +126,14 @@ class MyScaleVector(BaseVector):
|
|||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||
sql = f"""
|
||||
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
vector=r['vector'],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
|
|
|
@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector):
|
|||
docs = []
|
||||
for hit in response['hits']['hits']:
|
||||
metadata = hit['_source'].get(Field.METADATA_KEY.value)
|
||||
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
vector = hit['_source'].get(Field.VECTOR.value)
|
||||
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
|
|
@ -234,16 +234,16 @@ class OracleVector(BaseVector):
|
|||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)]
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata, text = record
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
metadata, text, embedding = record
|
||||
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||
return docs
|
||||
else:
|
||||
return [Document(page_content="", metadata="")]
|
||||
return [Document(page_content="", metadata={})]
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
|
|
|
@ -399,7 +399,6 @@ class QdrantVector(BaseVector):
|
|||
document = self._document_from_scored_point(
|
||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||
)
|
||||
document.metadata['vector'] = result.vector
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
@ -418,6 +417,7 @@ class QdrantVector(BaseVector):
|
|||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
vector=scored_point.vector,
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
|
|
|
@ -239,8 +239,7 @@ class WeaviateVector(BaseVector):
|
|||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||
if "errors" in result:
|
||||
|
@ -248,7 +247,8 @@ class WeaviateVector(BaseVector):
|
|||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
additional = res.pop('_additional')
|
||||
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
|
||||
return docs
|
||||
|
||||
def _default_schema(self, index_name: str) -> dict:
|
||||
|
|
|
@ -10,6 +10,8 @@ class Document(BaseModel):
|
|||
|
||||
page_content: str
|
||||
|
||||
vector: Optional[list[float]] = None
|
||||
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
|
|
|
@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model.get('reranking_model', None),
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
|
|
|
@ -44,7 +44,8 @@ class HitTestingService:
|
|||
top_k=retrieval_model.get('top_k', 2),
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model.get('reranking_model', None),
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
|
|
Loading…
Reference in New Issue
Block a user