mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
add score threshold
This commit is contained in:
parent
a401a73eb7
commit
3f34b7e103
|
@ -67,7 +67,11 @@ class TencentVector(BaseVector):
|
|||
def _create_collection(self, dimension: int) -> None:
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
self.delete()
|
||||
collections = self._db.list_collections()
|
||||
for collection in collections:
|
||||
if collection.collection_name == self._collection_name:
|
||||
self.collection = collection
|
||||
return
|
||||
index_type = None
|
||||
for k, v in enum.IndexType.__members__.items():
|
||||
if k == self._client_config.index_type:
|
||||
|
@ -153,12 +157,13 @@ class TencentVector(BaseVector):
|
|||
limit=kwargs.get('top_k', 4),
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
return self._get_search_res(res)
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res):
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
docs = []
|
||||
if res is None or len(res) == 0:
|
||||
return docs
|
||||
|
@ -167,8 +172,11 @@ class TencentVector(BaseVector):
|
|||
meta = result.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
score = 1 - result.get("score")
|
||||
if score > score_threshold:
|
||||
meta['score'] = score
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user