mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
use redis to cache embeddings (#2085)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
dc8a8af117
commit
a3c7c07ecc
|
@ -1,3 +1,5 @@
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
@ -5,6 +7,8 @@ import numpy as np
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.dataset import Embedding
|
from models.dataset import Embedding
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
|
||||||
embedding_queue_indices = []
|
embedding_queue_indices = []
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
hash = helper.generate_text_hash(text)
|
hash = helper.generate_text_hash(text)
|
||||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||||
|
embedding = redis_client.get(embedding_cache_key)
|
||||||
if embedding:
|
if embedding:
|
||||||
text_embeddings[i] = embedding.get_embedding()
|
redis_client.expire(embedding_cache_key, 3600)
|
||||||
|
text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
embedding_queue_indices.append(i)
|
embedding_queue_indices.append(i)
|
||||||
|
|
||||||
|
@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
|
||||||
hash = helper.generate_text_hash(texts[indice])
|
hash = helper.generate_text_hash(texts[indice])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||||
vector = embedding_results[i]
|
vector = embedding_results[i]
|
||||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||||
text_embeddings[indice] = normalized_embedding
|
text_embeddings[indice] = normalized_embedding
|
||||||
embedding.set_embedding(normalized_embedding)
|
# encode embedding to base64
|
||||||
db.session.add(embedding)
|
embedding_vector = np.array(normalized_embedding)
|
||||||
db.session.commit()
|
vector_bytes = embedding_vector.tobytes()
|
||||||
|
# Transform to Base64
|
||||||
|
encoded_vector = base64.b64encode(vector_bytes)
|
||||||
|
# Transform to string
|
||||||
|
encoded_str = encoded_vector.decode("utf-8")
|
||||||
|
redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
||||||
|
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
continue
|
continue
|
||||||
except:
|
except:
|
||||||
logging.exception('Failed to add embedding to db')
|
logging.exception('Failed to add embedding to redis')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
|
||||||
"""Embed query text."""
|
"""Embed query text."""
|
||||||
# use doc embedding cache or store if not exists
|
# use doc embedding cache or store if not exists
|
||||||
hash = helper.generate_text_hash(text)
|
hash = helper.generate_text_hash(text)
|
||||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
||||||
|
embedding = redis_client.get(embedding_cache_key)
|
||||||
if embedding:
|
if embedding:
|
||||||
return embedding.get_embedding()
|
redis_client.expire(embedding_cache_key, 3600)
|
||||||
|
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding_result = self._model_instance.invoke_text_embedding(
|
embedding_result = self._model_instance.invoke_text_embedding(
|
||||||
|
@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
try:
|
try:
|
||||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
# encode embedding to base64
|
||||||
embedding.set_embedding(embedding_results)
|
embedding_vector = np.array(embedding_results)
|
||||||
db.session.add(embedding)
|
vector_bytes = embedding_vector.tobytes()
|
||||||
db.session.commit()
|
# Transform to Base64
|
||||||
|
encoded_vector = base64.b64encode(vector_bytes)
|
||||||
|
# Transform to string
|
||||||
|
encoded_str = encoded_vector.decode("utf-8")
|
||||||
|
redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
||||||
|
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
except:
|
except:
|
||||||
logging.exception('Failed to add embedding to db')
|
logging.exception('Failed to add embedding to redis')
|
||||||
|
|
||||||
return embedding_results
|
return embedding_results
|
||||||
|
|
Loading…
Reference in New Issue
Block a user