diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 285e2ba388..94ae1556aa 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,3 +1,5 @@ +import base64 +import json import logging from typing import List, Optional @@ -5,6 +7,8 @@ import numpy as np from core.model_manager import ModelInstance from extensions.ext_database import db from langchain.embeddings.base import Embeddings + +from extensions.ext_redis import redis_client from libs import helper from models.dataset import Embedding from sqlalchemy.exc import IntegrityError @@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() + embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' + embedding = redis_client.get(embedding_cache_key) if embedding: - text_embeddings[i] = embedding.get_embedding() + redis_client.expire(embedding_cache_key, 3600) + text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float")) + else: embedding_queue_indices.append(i) @@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings): hash = helper.generate_text_hash(texts[indice]) try: - embedding = Embedding(model_name=self._model_instance.model, hash=hash) + embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' vector = embedding_results[i] normalized_embedding = (vector / np.linalg.norm(vector)).tolist() text_embeddings[indice] = normalized_embedding - embedding.set_embedding(normalized_embedding) - db.session.add(embedding) - db.session.commit() + # encode embedding to base64 + embedding_vector = np.array(normalized_embedding) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 3600, encoded_str) + except IntegrityError: db.session.rollback() continue except: - logging.exception('Failed to add embedding to db') + logging.exception('Failed to add embedding to redis') continue return text_embeddings @@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings): """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() + embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' + embedding = redis_client.get(embedding_cache_key) if embedding: - return embedding.get_embedding() + redis_client.expire(embedding_cache_key, 3600) + return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) + try: embedding_result = self._model_instance.invoke_text_embedding( @@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings): raise ex try: - embedding = Embedding(model_name=self._model_instance.model, hash=hash) - embedding.set_embedding(embedding_results) - db.session.add(embedding) - db.session.commit() + # encode embedding to base64 + embedding_vector = np.array(embedding_results) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 3600, encoded_str) + except IntegrityError: db.session.rollback() except: - logging.exception('Failed to add embedding to db') + logging.exception('Failed to add embedding to redis') return embedding_results