mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
fix: embedding init err (#956)
This commit is contained in:
parent
a7c78d2cd2
commit
78d3aa5fcd
|
@ -1,4 +1,4 @@
|
|||
from langchain.embeddings import XinferenceEmbeddings
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
|
@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding):
|
|||
)
|
||||
|
||||
client = XinferenceEmbeddings(
|
||||
**credentials,
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid'],
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
|
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
21
api/core/third_party/langchain/embeddings/xinference_embedding.py
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
|
||||
class XinferenceEmbedding(XinferenceEmbeddings):
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
vectors = super().embed_documents(texts)
|
||||
|
||||
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
|
||||
|
||||
return normalized_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
vector = super().embed_query(text)
|
||||
|
||||
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
|
||||
|
||||
return normalized_vector
|
Loading…
Reference in New Issue
Block a user