mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
add vector field for other vectordb (#7051)
This commit is contained in:
parent
aad02113c6
commit
80c94f02e9
|
@ -244,7 +244,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for message in self._queue_manager.listen():
|
for message in self._queue_manager.listen():
|
||||||
if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
|
if message.event and hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher:
|
||||||
publisher.publish(message=message)
|
publisher.publish(message=message)
|
||||||
elif (hasattr(message.event, 'execution_metadata')
|
elif (hasattr(message.event, 'execution_metadata')
|
||||||
and message.event.execution_metadata
|
and message.event.execution_metadata
|
||||||
|
|
|
@ -285,9 +285,11 @@ class AnalyticdbVector(BaseVector):
|
||||||
documents = []
|
documents = []
|
||||||
for match in response.body.matches.match:
|
for match in response.body.matches.match:
|
||||||
if match.score > score_threshold:
|
if match.score > score_threshold:
|
||||||
|
metadata = json.loads(match.metadata.get("metadata_"))
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=match.metadata.get("page_content"),
|
page_content=match.metadata.get("page_content"),
|
||||||
metadata=json.loads(match.metadata.get("metadata_")),
|
vector=match.metadata.get("vector"),
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
return documents
|
return documents
|
||||||
|
|
|
@ -126,13 +126,14 @@ class MyScaleVector(BaseVector):
|
||||||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=r["text"],
|
page_content=r["text"],
|
||||||
|
vector=r['vector'],
|
||||||
metadata=r["metadata"],
|
metadata=r["metadata"],
|
||||||
)
|
)
|
||||||
for r in self._client.query(sql).named_results()
|
for r in self._client.query(sql).named_results()
|
||||||
|
|
|
@ -192,7 +192,9 @@ class OpenSearchVector(BaseVector):
|
||||||
docs = []
|
docs = []
|
||||||
for hit in response['hits']['hits']:
|
for hit in response['hits']['hits']:
|
||||||
metadata = hit['_source'].get(Field.METADATA_KEY.value)
|
metadata = hit['_source'].get(Field.METADATA_KEY.value)
|
||||||
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
|
vector = hit['_source'].get(Field.VECTOR.value)
|
||||||
|
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
|
||||||
|
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
|
@ -234,16 +234,16 @@ class OracleVector(BaseVector):
|
||||||
entities.append(token)
|
entities.append(token)
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||||
[" ACCUM ".join(entities)]
|
[" ACCUM ".join(entities)]
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
for record in cur:
|
for record in cur:
|
||||||
metadata, text = record
|
metadata, text, embedding = record
|
||||||
docs.append(Document(page_content=text, metadata=metadata))
|
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
else:
|
else:
|
||||||
return [Document(page_content="", metadata="")]
|
return [Document(page_content="", metadata={})]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
|
|
@ -399,7 +399,6 @@ class QdrantVector(BaseVector):
|
||||||
document = self._document_from_scored_point(
|
document = self._document_from_scored_point(
|
||||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||||
)
|
)
|
||||||
document.metadata['vector'] = result.vector
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
@ -418,6 +417,7 @@ class QdrantVector(BaseVector):
|
||||||
) -> Document:
|
) -> Document:
|
||||||
return Document(
|
return Document(
|
||||||
page_content=scored_point.payload.get(content_payload_key),
|
page_content=scored_point.payload.get(content_payload_key),
|
||||||
|
vector=scored_point.vector,
|
||||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -239,8 +239,7 @@ class WeaviateVector(BaseVector):
|
||||||
query_obj = self._client.query.get(collection_name, properties)
|
query_obj = self._client.query.get(collection_name, properties)
|
||||||
if kwargs.get("where_filter"):
|
if kwargs.get("where_filter"):
|
||||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||||
if kwargs.get("additional"):
|
query_obj = query_obj.with_additional(["vector"])
|
||||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
|
||||||
properties = ['text']
|
properties = ['text']
|
||||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
|
@ -248,7 +247,8 @@ class WeaviateVector(BaseVector):
|
||||||
docs = []
|
docs = []
|
||||||
for res in result["data"]["Get"][collection_name]:
|
for res in result["data"]["Get"][collection_name]:
|
||||||
text = res.pop(Field.TEXT_KEY.value)
|
text = res.pop(Field.TEXT_KEY.value)
|
||||||
docs.append(Document(page_content=text, metadata=res))
|
additional = res.pop('_additional')
|
||||||
|
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _default_schema(self, index_name: str) -> dict:
|
def _default_schema(self, index_name: str) -> dict:
|
||||||
|
|
|
@ -10,6 +10,8 @@ class Document(BaseModel):
|
||||||
|
|
||||||
page_content: str
|
page_content: str
|
||||||
|
|
||||||
|
vector: Optional[list[float]] = None
|
||||||
|
|
||||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||||
documents, etc.).
|
documents, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model.get('reranking_model', None),
|
reranking_model=retrieval_model.get('reranking_model', None)
|
||||||
|
if retrieval_model['reranking_enable'] else None,
|
||||||
reranking_mode=retrieval_model.get('reranking_mode')
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
|
|
|
@ -44,7 +44,8 @@ class HitTestingService:
|
||||||
top_k=retrieval_model.get('top_k', 2),
|
top_k=retrieval_model.get('top_k', 2),
|
||||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model.get('reranking_model', None),
|
reranking_model=retrieval_model.get('reranking_model', None)
|
||||||
|
if retrieval_model['reranking_enable'] else None,
|
||||||
reranking_mode=retrieval_model.get('reranking_mode')
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user