diff --git a/api/core/model_providers/models/embedding/huggingface_embedding.py b/api/core/model_providers/models/embedding/huggingface_embedding.py
new file mode 100644
index 0000000000..61af1cede8
--- /dev/null
+++ b/api/core/model_providers/models/embedding/huggingface_embedding.py
@@ -0,0 +1,22 @@
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
+from core.model_providers.models.embedding.base import BaseEmbedding
+
+
+class HuggingfaceEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = HuggingfaceHubEmbeddings(
+            model=name,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")
diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py
index 75fffac722..deae4e35df 100644
--- a/api/core/model_providers/providers/huggingface_hub_provider.py
+++ b/api/core/model_providers/providers/huggingface_hub_provider.py
@@ -1,5 +1,6 @@
 import json
 from typing import Type
+import requests
 
 from huggingface_hub import HfApi
 
@@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
 from core.model_providers.models.base import BaseProviderModel
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
 from models.provider import ProviderType
 
+HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
+
 
 class HuggingfaceHubProvider(BaseModelProvider):
     @property
@@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
         """
         if model_type == ModelType.TEXT_GENERATION:
             model_class = HuggingfaceHubModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = HuggingfaceEmbedding
         else:
             raise NotImplementedError
@@ -63,7 +70,7 @@
         :param model_type:
         :param credentials:
         """
-        if model_type != ModelType.TEXT_GENERATION:
+        if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
             raise NotImplementedError
 
         if 'huggingfacehub_api_type' not in credentials \
@@ -88,19 +95,15 @@
         if 'task_type' not in credentials:
             raise CredentialsValidateFailedError('Task Type must be provided.')
 
-        if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+        if credentials['task_type'] not in ("text2text-generation", "text-generation", "feature-extraction"):
             raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
-                                                 'text-generation, summarization.')
+                                                 'text-generation, feature-extraction.')
 
         try:
-            llm = HuggingFaceEndpointLLM(
-                endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                task=credentials['task_type'],
-                model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
-                huggingfacehub_api_token=credentials['huggingfacehub_api_token']
-            )
-
-            llm("ping")
+            if credentials['task_type'] == 'feature-extraction':
+                cls.check_embedding_valid(credentials, model_name)
+            else:
+                cls.check_llm_valid(credentials)
         except Exception as e:
             raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
         else:
@@ -112,13 +115,63 @@
             if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                 raise ValueError(f'Inference API has been turned off for this model {model_name}.')
 
-            VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+            VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
             if model_info.pipeline_tag not in VALID_TASKS:
                 raise ValueError(f"Model {model_name} is not a valid task, "
                                  f"must be one of {VALID_TASKS}.")
         except Exception as e:
             raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
 
+    @classmethod
+    def check_llm_valid(cls, credentials: dict):
+        llm = HuggingFaceEndpointLLM(
+            endpoint_url=credentials['huggingfacehub_endpoint_url'],
+            task=credentials['task_type'],
+            model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
+            huggingfacehub_api_token=credentials['huggingfacehub_api_token']
+        )
+
+        llm("ping")
+
+    @classmethod
+    def check_embedding_valid(cls, credentials: dict, model_name: str):
+        cls.check_endpoint_url_model_repository_name(credentials, model_name)
+
+        embedding_model = HuggingfaceHubEmbeddings(
+            model=model_name,
+            **credentials
+        )
+
+        embedding_model.embed_query("ping")
+
+    @classmethod
+    def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
+        try:
+            # List the namespace's inference endpoints and make sure the
+            # configured endpoint URL actually serves the given model.
+            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
+            headers = {
+                'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
+                'Content-Type': 'application/json'
+            }
+
+            response = requests.get(url=url, headers=headers)
+
+            if response.status_code != 200:
+                raise ValueError('User Name or Organization Name is invalid.')
+
+            model_repository_name = ''
+
+            for item in response.json().get("items", []):
+                if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
+                    model_repository_name = item.get("model", {}).get("repository")
+                    break
+
+            if model_repository_name != model_name:
+                raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
+
+        except Exception as e:
+            raise ValueError(str(e))
+
+
     @classmethod
     def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                   credentials: dict) -> dict:
diff --git a/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
new file mode 100644
index 0000000000..c5d7d7f8db
--- /dev/null
+++ b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
@@ -0,0 +1,74 @@
+from typing import Any, Dict, List, Optional
+import json
+import numpy as np
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from huggingface_hub import InferenceClient
+
+HOSTED_INFERENCE_API = 'hosted_inference_api'
+INFERENCE_ENDPOINTS = 'inference_endpoints'
+
+
+class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
+    client: Any
+    model: str
+
+    huggingface_namespace: Optional[str] = None
+    task_type: Optional[str] = None
+    huggingfacehub_api_type: Optional[str] = None
+    huggingfacehub_api_token: Optional[str] = None
+    huggingfacehub_endpoint_url: Optional[str] = None
+
+    class Config:
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        values['huggingfacehub_api_token'] = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        # The hosted Inference API is addressed by model name; a
+        # dedicated Inference Endpoint is addressed by its URL.
+        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
+            model = self.model
+        else:
+            model = self.huggingfacehub_endpoint_url
+
+        output = self.client.post(
+            json={
+                "inputs": texts,
+                "options": {
+                    "wait_for_model": False,
+                    "use_cache": False
+                }
+            }, model=model)
+
+        embeddings = json.loads(output.decode())
+        return self.mean_pooling(embeddings)
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
+
+    # https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
+    # The API returns a list of floats or a list of lists of floats per
+    # input, depending on whether a string or a list of strings was sent,
+    # and on whether automatic reduction (usually mean pooling) was
+    # applied for you. This should be explained in the model's README.
+    def mean_pooling(self, embeddings: List) -> List[List[float]]:
+        # Reduction already applied by the model: each item is List[float].
+        if not isinstance(embeddings[0][0], list):
+            return embeddings
+
+        # Token-level output, List[List[List[float]]]: mean-pool over the
+        # token axis to get one sentence vector per input.
+        sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
+        return sentence_embeddings
diff --git a/api/core/third_party/langchain/llms/huggingface_hub_llm.py b/api/core/third_party/langchain/llms/huggingface_hub_llm.py
index 4e8a2e3446..31a91f5dc5 100644
--- a/api/core/third_party/langchain/llms/huggingface_hub_llm.py
+++ b/api/core/third_party/langchain/llms/huggingface_hub_llm.py
@@ -16,7 +16,7 @@ class HuggingFaceHubLLM(HuggingFaceHub):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your
     API token, or pass it as a named parameter to the constructor.
 
-    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+    Only supports `text-generation` and `text2text-generation` for now.
 
     Example:
         .. code-block:: python
diff --git a/api/requirements.txt b/api/requirements.txt
index 48766a4458..d616d50972 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -51,4 +51,4 @@ stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
 safetensors==0.3.2
-zhipuai==1.0.7
+zhipuai==1.0.7
\ No newline at end of file
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 3400cfaddb..fed6c12e54 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -14,6 +14,7 @@ REPLICATE_API_TOKEN=
 # Hugging Face API Key
 HUGGINGFACE_API_KEY=
 HUGGINGFACE_ENDPOINT_URL=
+HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
 
 # Minimax Credentials
 MINIMAX_API_KEY=
diff --git a/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
new file mode 100644
index 0000000000..452af9e726
--- /dev/null
+++ b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
@@ -0,0 +1,140 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
+from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
+
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='huggingface_hub',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
+    valid_api_key = os.environ['HUGGINGFACE_API_KEY']
+    endpoint_url = os.environ['HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL']
+    model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
+
+    credentials = {
+        'huggingfacehub_api_type': huggingfacehub_api_type,
+        'huggingfacehub_api_token': valid_api_key,
+        'task_type': 'feature-extraction'
+    }
+
+    if huggingfacehub_api_type == 'inference_endpoints':
+        credentials['huggingfacehub_endpoint_url'] = endpoint_url
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='huggingface_hub',
+        model_name=model_name,
+        model_type=ModelType.EMBEDDINGS.value,
+        encrypted_config=json.dumps(credentials),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query',
+                 return_value=mock_query)
+
+    return HuggingfaceEmbedding(
+        model_provider=model_provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
+                 return_value=bytes(json.dumps([[1, 2, 3], [4, 5, 6]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 3
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents_two(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
+                 return_value=bytes(json.dumps([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 3
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
+                 return_value=bytes(json.dumps([[1, 2, 3]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 3
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query_two(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post',
+                 return_value=bytes(json.dumps([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 3
\ No newline at end of file
diff --git a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
index efcc578592..956185dabf 100644
--- a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
+++ b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
@@ -48,6 +48,15 @@ const config: ProviderConfig = {
        ]
      }
      if (v?.huggingfacehub_api_type === 'inference_endpoints') {
+        if (v.model_type === 'embeddings') {
+          return [
+            'huggingfacehub_api_token',
+            'huggingface_namespace',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+          ]
+        }
        return [
          'huggingfacehub_api_token',
          'model_name',
@@ -68,14 +77,27 @@ const config: ProviderConfig = {
        ]
      }
      if (v?.huggingfacehub_api_type === 'inference_endpoints') {
-        filteredKeys = [
-          'huggingfacehub_api_type',
-          'huggingfacehub_api_token',
-          'model_name',
-          'huggingfacehub_endpoint_url',
-          'task_type',
-          'model_type',
-        ]
+        if (v.model_type === 'embeddings') {
+          filteredKeys = [
+            'huggingfacehub_api_type',
+            'huggingfacehub_api_token',
+            'huggingface_namespace',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+            'model_type',
+          ]
+        }
+        else {
+          filteredKeys = [
+            'huggingfacehub_api_type',
+            'huggingfacehub_api_token',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+            'model_type',
+          ]
+        }
      }
      return filteredKeys.reduce((prev: FormValue, next: string) => {
        prev[next] = v?.[next] || ''
@@ -83,6 +105,31 @@
      }, {})
    },
    fields: [
+      {
+        type: 'radio',
+        key: 'model_type',
+        required: true,
+        label: {
+          'en': 'Model Type',
+          'zh-Hans': '模型类型',
+        },
+        options: [
+          {
+            key: 'text-generation',
+            label: {
+              'en': 'Text Generation',
+              'zh-Hans': '文本生成',
+            },
+          },
+          {
+            key: 'embeddings',
+            label: {
+              'en': 'Embeddings',
+              'zh-Hans': 'Embeddings',
+            },
+          },
+        ],
+      },
      {
        type: 'radio',
        key: 'huggingfacehub_api_type',
@@ -121,6 +168,20 @@
          'zh-Hans': '在此输入您的 Hugging Face Hub API Token',
        },
      },
+      {
+        hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
+        type: 'text',
+        key: 'huggingface_namespace',
+        required: true,
+        label: {
+          'en': 'User Name / Organization Name',
+          'zh-Hans': '用户名 / 组织名称',
+        },
+        placeholder: {
+          'en': 'Enter your User Name / Organization Name here',
+          'zh-Hans': '在此输入您的用户名 / 组织名称',
+        },
+      },
      {
        type: 'text',
        key: 'model_name',
@@ -148,7 +209,7 @@
        },
      },
      {
-        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
        type: 'radio',
        key: 'task_type',
        required: true,
@@ -173,6 +234,25 @@
          },
        ],
      },
+      {
+        hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
+        type: 'radio',
+        key: 'task_type',
+        required: true,
+        label: {
+          'en': 'Task',
+          'zh-Hans': 'Task',
+        },
+        options: [
+          {
+            key: 'feature-extraction',
+            label: {
+              'en': 'Feature Extraction',
+              'zh-Hans': 'Feature Extraction',
+            },
+          },
+        ],
+      },
    ],
  },
}
diff --git a/web/app/components/header/account-setting/model-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-page/model-modal/Form.tsx
index 6f3838c860..da0f91e838 100644
--- a/web/app/components/header/account-setting/model-page/model-modal/Form.tsx
+++ b/web/app/components/header/account-setting/model-page/model-modal/Form.tsx
@@ -1,7 +1,7 @@
 import { useEffect, useState } from 'react'
 import type { Dispatch, FC, SetStateAction } from 'react'
 import { useContext } from 'use-context-selector'
-import type { Field, FormValue, ProviderConfigModal } from '../declarations'
+import { type Field, type FormValue, type ProviderConfigModal, ProviderEnum } from '../declarations'
 import { useValidate } from '../../key-validator/hooks'
 import { ValidatingTip } from '../../key-validator/ValidateStatus'
 import { validateModelProviderFn } from '../utils'
@@ -85,10 +85,31 @@ const Form: FC = ({
   }
 
   const handleFormChange = (k: string, v: string) => {
-    if (mode === 'edit' && !cleared)
+    if (mode === 'edit' && !cleared) {
       handleClear({ [k]: v })
-    else
-      handleMultiFormChange({ ...value, [k]: v }, k)
+    }
+    else {
+      const extraValue: Record<string, string> = {}
+      if (
+        (
+          (k === 'model_type' && v === 'embeddings' && value.huggingfacehub_api_type === 'inference_endpoints')
+          || (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'embeddings')
+        )
+        && modelModal?.key === ProviderEnum.huggingface_hub
+      )
+        extraValue.task_type = 'feature-extraction'
+
+      if (
+        (
+          (k === 'model_type' && v === 'text-generation' && value.huggingfacehub_api_type === 'inference_endpoints')
+          || (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'text-generation')
+        )
+        && modelModal?.key === ProviderEnum.huggingface_hub
+      )
+        extraValue.task_type = 'text-generation'
+
+      handleMultiFormChange({ ...value, [k]: v, ...extraValue }, k)
+    }
   }
 
   const handleFocus = () => {
diff --git a/web/app/components/header/account-setting/model-page/model-modal/index.tsx b/web/app/components/header/account-setting/model-page/model-modal/index.tsx
index 57426ec60a..385e064ffb 100644
--- a/web/app/components/header/account-setting/model-page/model-modal/index.tsx
+++ b/web/app/components/header/account-setting/model-page/model-modal/index.tsx
@@ -92,7 +92,7 @@ const ModelModal: FC = ({
   return (
-
+
           {renderTitlePrefix()}
         </div>