mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
use redis to cache embeddings (#2085)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
dc8a8af117
commit
a3c7c07ecc
|
@ -1,3 +1,5 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -5,6 +7,8 @@ import numpy as np
|
|||
from core.model_manager import ModelInstance
|
||||
from extensions.ext_database import db
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
|
|||
embedding_queue_indices = []
|
||||
for i, text in enumerate(texts):
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
||||
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
text_embeddings[i] = embedding.get_embedding()
|
||||
redis_client.expire(embedding_cache_key, 3600)
|
||||
text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
||||
|
||||
else:
|
||||
embedding_queue_indices.append(i)
|
||||
|
||||
|
@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
|
|||
hash = helper.generate_text_hash(texts[indice])
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
||||
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||
vector = embedding_results[i]
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
text_embeddings[indice] = normalized_embedding
|
||||
embedding.set_embedding(normalized_embedding)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
# encode embedding to base64
|
||||
embedding_vector = np.array(normalized_embedding)
|
||||
vector_bytes = embedding_vector.tobytes()
|
||||
# Transform to Base64
|
||||
encoded_vector = base64.b64encode(vector_bytes)
|
||||
# Transform to string
|
||||
encoded_str = encoded_vector.decode("utf-8")
|
||||
redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
||||
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
continue
|
||||
except:
|
||||
logging.exception('Failed to add embedding to db')
|
||||
logging.exception('Failed to add embedding to redis')
|
||||
continue
|
||||
|
||||
return text_embeddings
|
||||
|
@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
|
|||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
||||
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
return embedding.get_embedding()
|
||||
redis_client.expire(embedding_cache_key, 3600)
|
||||
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
||||
|
||||
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
|
@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
|
|||
raise ex
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
||||
embedding.set_embedding(embedding_results)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
# encode embedding to base64
|
||||
embedding_vector = np.array(embedding_results)
|
||||
vector_bytes = embedding_vector.tobytes()
|
||||
# Transform to Base64
|
||||
encoded_vector = base64.b64encode(vector_bytes)
|
||||
# Transform to string
|
||||
encoded_str = encoded_vector.decode("utf-8")
|
||||
redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
||||
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
except:
|
||||
logging.exception('Failed to add embedding to db')
|
||||
logging.exception('Failed to add embedding to redis')
|
||||
|
||||
return embedding_results
|
||||
|
|
Loading…
Reference in New Issue
Block a user