From 4ab4bcc074fe311fc061b7c2631a7c7457ef9a2d Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 10 Oct 2023 12:09:35 +0800 Subject: [PATCH] feat: support openllm embedding (#1293) --- .../models/embedding/openllm_embedding.py | 22 ++++++ .../models/embedding/xinference_embedding.py | 6 +- .../providers/openllm_provider.py | 25 +++++-- .../langchain/embeddings/openllm_embedding.py | 67 +++++++++++++++++++ .../embedding/test_openllm_embedding.py | 63 +++++++++++++++++ 5 files changed, 171 insertions(+), 12 deletions(-) create mode 100644 api/core/model_providers/models/embedding/openllm_embedding.py create mode 100644 api/core/third_party/langchain/embeddings/openllm_embedding.py create mode 100644 api/tests/integration_tests/models/embedding/test_openllm_embedding.py diff --git a/api/core/model_providers/models/embedding/openllm_embedding.py b/api/core/model_providers/models/embedding/openllm_embedding.py new file mode 100644 index 0000000000..8d27815bfc --- /dev/null +++ b/api/core/model_providers/models/embedding/openllm_embedding.py @@ -0,0 +1,22 @@ +from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.providers.base import BaseModelProvider +from core.model_providers.models.embedding.base import BaseEmbedding + + +class OpenLLMEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = OpenLLMEmbeddings( + server_url=credentials['server_url'] + ) + + super().__init__(model_provider, client, name) + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}") diff --git a/api/core/model_providers/models/embedding/xinference_embedding.py b/api/core/model_providers/models/embedding/xinference_embedding.py index ba8dd2d27d..4cecd511e5 100644 --- a/api/core/model_providers/models/embedding/xinference_embedding.py +++ b/api/core/model_providers/models/embedding/xinference_embedding.py @@ -1,5 +1,4 @@ from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings -from replicate.exceptions import ModelError, ReplicateError from core.model_providers.error import LLMBadRequestError from core.model_providers.providers.base import BaseModelProvider @@ -21,7 +20,4 @@ class XinferenceEmbedding(BaseEmbedding): super().__init__(model_provider, client, name) def handle_exceptions(self, ex: Exception) -> Exception: - if isinstance(ex, (ModelError, ReplicateError)): - return LLMBadRequestError(f"Xinference embedding: {str(ex)}") - else: - return ex + return LLMBadRequestError(f"Xinference embedding: {str(ex)}") diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py index f1274a8082..a691507b9f 100644 --- a/api/core/model_providers/providers/openllm_provider.py +++ b/api/core/model_providers/providers/openllm_provider.py @@ -2,11 +2,13 @@ import json from typing import Type from core.helper import encrypter +from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.models.base import BaseProviderModel +from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings from core.third_party.langchain.llms.openllm import OpenLLM from models.provider import ProviderType @@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider): """ if model_type == ModelType.TEXT_GENERATION: model_class = OpenLLMModel + elif model_type== ModelType.EMBEDDINGS: + model_class = OpenLLMEmbedding else: raise NotImplementedError @@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider): 'server_url': credentials['server_url'] } - llm = OpenLLM( - llm_kwargs={ - 'max_new_tokens': 10 - }, - **credential_kwargs - ) + if model_type == ModelType.TEXT_GENERATION: + llm = OpenLLM( + llm_kwargs={ + 'max_new_tokens': 10 + }, + **credential_kwargs + ) - llm("ping") + llm("ping") + elif model_type == ModelType.EMBEDDINGS: + embedding = OpenLLMEmbeddings( + **credential_kwargs + ) + + embedding.embed_query("ping") except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/third_party/langchain/embeddings/openllm_embedding.py b/api/core/third_party/langchain/embeddings/openllm_embedding.py new file mode 100644 index 0000000000..9c87323b84 --- /dev/null +++ b/api/core/third_party/langchain/embeddings/openllm_embedding.py @@ -0,0 +1,67 @@ +"""Wrapper around OpenLLM embedding models.""" +from typing import Any, List, Optional + +import requests +from pydantic import BaseModel, Extra + +from langchain.embeddings.base import Embeddings + + +class OpenLLMEmbeddings(BaseModel, Embeddings): + """Wrapper around OpenLLM embedding models. + """ + + client: Any #: :meta private: + + server_url: Optional[str] = None + """Optional server URL that currently runs a LLMServer with 'openllm start'.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to OpenLLM's embedding endpoint. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + embeddings = [] + for text in texts: + result = self.invoke_embedding(text=text) + embeddings.append(result) + + return [list(map(float, e)) for e in embeddings] + + def invoke_embedding(self, text): + params = [ + text + ] + + headers = {"Content-Type": "application/json"} + response = requests.post( + f'{self.server_url}/v1/embeddings', + headers=headers, + json=params + ) + + if not response.ok: + raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}") + + json_response = response.json() + return json_response[0]["embeddings"][0] + + def embed_query(self, text: str) -> List[float]: + """Call out to OpenLLM's embedding endpoint. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/api/tests/integration_tests/models/embedding/test_openllm_embedding.py b/api/tests/integration_tests/models/embedding/test_openllm_embedding.py new file mode 100644 index 0000000000..29c24af6b4 --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_openllm_embedding.py @@ -0,0 +1,63 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.openllm_provider import OpenLLMProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='openllm', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_embedding_model(mocker): + model_name = 'facebook/opt-125m' + server_url = os.environ['OPENLLM_SERVER_URL'] + model_provider = OpenLLMProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='openllm', + model_name=model_name, + model_type=ModelType.EMBEDDINGS.value, + encrypted_config=json.dumps({ + 'server_url': server_url + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return OpenLLMEmbedding( + model_provider=model_provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embed_documents(mock_decrypt, mocker): + embedding_model = get_mock_embedding_model(mocker) + rst = embedding_model.client.embed_documents(['test', 'test1']) + assert isinstance(rst, list) + assert len(rst) == 2 + assert len(rst[0]) > 0 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embed_query(mock_decrypt, mocker): + embedding_model = get_mock_embedding_model(mocker) + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) > 0