use redis to cache embeddings (#2085)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong 2024-01-18 21:39:12 +08:00 committed by GitHub
parent dc8a8af117
commit a3c7c07ecc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,5 @@
import base64
import json
import logging import logging
from typing import List, Optional from typing import List, Optional
@ -5,6 +7,8 @@ import numpy as np
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from extensions.ext_database import db from extensions.ext_database import db
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from extensions.ext_redis import redis_client
from libs import helper from libs import helper
from models.dataset import Embedding from models.dataset import Embedding
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = [] embedding_queue_indices = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key)
if embedding: if embedding:
text_embeddings[i] = embedding.get_embedding() redis_client.expire(embedding_cache_key, 3600)
text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
else: else:
embedding_queue_indices.append(i) embedding_queue_indices.append(i)
@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[indice]) hash = helper.generate_text_hash(texts[indice])
try: try:
embedding = Embedding(model_name=self._model_instance.model, hash=hash) embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
vector = embedding_results[i] vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
text_embeddings[indice] = normalized_embedding text_embeddings[indice] = normalized_embedding
embedding.set_embedding(normalized_embedding) # encode embedding to base64
db.session.add(embedding) embedding_vector = np.array(normalized_embedding)
db.session.commit() vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
continue continue
except: except:
logging.exception('Failed to add embedding to db') logging.exception('Failed to add embedding to redis')
continue continue
return text_embeddings return text_embeddings
@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
"""Embed query text.""" """Embed query text."""
# use doc embedding cache or store if not exists # use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key)
if embedding: if embedding:
return embedding.get_embedding() redis_client.expire(embedding_cache_key, 3600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try: try:
embedding_result = self._model_instance.invoke_text_embedding( embedding_result = self._model_instance.invoke_text_embedding(
@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
raise ex raise ex
try: try:
embedding = Embedding(model_name=self._model_instance.model, hash=hash) # encode embedding to base64
embedding.set_embedding(embedding_results) embedding_vector = np.array(embedding_results)
db.session.add(embedding) vector_bytes = embedding_vector.tobytes()
db.session.commit() # Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
except: except:
logging.exception('Failed to add embedding to db') logging.exception('Failed to add embedding to redis')
return embedding_results return embedding_results