diff --git a/api/config.py b/api/config.py
index 1e6000c8ae..f81527da61 100644
--- a/api/config.py
+++ b/api/config.py
@@ -47,6 +47,7 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+    'DEFAULT_LLM_PROVIDER': 'openai'
 }
 
 
@@ -181,6 +182,10 @@ class Config:
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
 
+        # Temporary setting:
+        # the default LLM provider; defaults to 'openai', `azure_openai` is also supported
+        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+
 
 class CloudEditionConfig(Config):
 
     def __init__(self):
diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py
index bc6b8320af..dc9e9c45f1 100644
--- a/api/controllers/console/workspace/providers.py
+++ b/api/controllers/console/workspace/providers.py
@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
 
         args = parser.parse_args()
 
-        if not args['token']:
-            raise ValueError('Token is empty')
+        if args['token']:
+            try:
+                ProviderService.validate_provider_configs(
+                    tenant=current_user.current_tenant,
+                    provider_name=ProviderName(provider),
+                    configs=args['token']
+                )
+                token_is_valid = True
+            except ValidateFailedError:
+                token_is_valid = False
 
-        try:
-            ProviderService.validate_provider_configs(
+            base64_encrypted_token = ProviderService.get_encrypted_token(
                 tenant=current_user.current_tenant,
                 provider_name=ProviderName(provider),
                 configs=args['token']
             )
-            token_is_valid = True
-        except ValidateFailedError:
+        else:
+            base64_encrypted_token = None
             token_is_valid = False
 
         tenant = current_user.current_tenant
 
-        base64_encrypted_token = ProviderService.get_encrypted_token(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider),
-            configs=args['token']
-        )
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
-                                                  provider_type=ProviderType.CUSTOM.value).first()
+        provider_model = db.session.query(Provider).filter(
+            Provider.tenant_id == tenant.id,
+            Provider.provider_name == provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()
 
         # Only allow updating token for CUSTOM provider type
         if provider_model:
@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)
 
+        if provider_model.is_valid:
+            other_providers = db.session.query(Provider).filter(
+                Provider.tenant_id == tenant.id,
+                Provider.provider_name != provider,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).all()
+
+            for other_provider in other_providers:
+                other_provider.is_valid = False
+
         db.session.commit()
 
         if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
diff --git a/api/core/embedding/openai_embedding.py b/api/core/embedding/openai_embedding.py
index 0938397423..0f7cb252e2 100644
--- a/api/core/embedding/openai_embedding.py
+++ b/api/core/embedding/openai_embedding.py
@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embedding(
-    text: str,
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None,
+        text: str,
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[float]:
     """Get embedding.
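The `**kwargs` threaded through these helpers is what lets the Azure path reuse the same functions: any extra OpenAI client parameters are forwarded untouched to `openai.Embedding.create`. A minimal sketch of an Azure-style call, with placeholder credentials (the parameter names mirror the call sites changed further down in this file):

```python
# Hypothetical values; with api_type="azure", `engine` names the deployment.
vector = get_embedding(
    "What is Dify?",
    engine="text-embedding-ada-002",
    api_key="<azure-api-key>",
    api_type="azure",
    api_version="2023-03-15-preview",
    api_base="https://<resource-name>.openai.azure.com/",
)
```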
@@ -25,11 +26,12 @@ def get_embedding(
     """
     text = text.replace("\n", " ")
 
-    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
+    float]:
     """Asynchronously get embedding.
 
     NOTE: Copied from OpenAI's embedding utils:
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
         "embedding"
     ]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str],
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None
+        list_of_text: List[str],
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[List[float]]:
     """Get embeddings.
 
@@ -67,14 +70,14 @@ def get_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+    list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 ) -> List[List[float]]:
     """Asynchronously get embeddings.
 
@@ -90,7 +93,7 @@ async def aget_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
@@ -98,19 +101,30 @@
 
 class OpenAIEmbedding(BaseEmbedding):
     def __init__(
-        self,
-        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-        deployment_name: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        **kwargs: Any,
+            self,
+            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
+            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
+            deployment_name: Optional[str] = None,
+            openai_api_key: Optional[str] = None,
+            **kwargs: Any,
     ) -> None:
         """Init params."""
-        super().__init__(**kwargs)
+        new_kwargs = {}
+
+        if 'embed_batch_size' in kwargs:
+            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+        if 'tokenizer' in kwargs:
+            new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+        super().__init__(**new_kwargs)
         self.mode = OpenAIEmbeddingMode(mode)
         self.model = OpenAIEmbeddingModelType(model)
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
+        self.openai_api_type = kwargs.get('openai_api_type')
+        self.openai_api_version = kwargs.get('openai_api_version')
+        self.openai_api_base = kwargs.get('openai_api_base')
 
     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _QUERY_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings.
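Since the constructor now lifts `openai_api_type`, `openai_api_version`, and `openai_api_base` out of `kwargs`, pointing the embedding model at Azure becomes purely a matter of instantiation. A rough sketch, assuming the credential shape that `AzureProvider.get_credentials` returns later in this diff (key and resource name are placeholders):

```python
embedding_model = OpenAIEmbedding(
    deployment_name="text-embedding-ada-002",
    openai_api_key="<azure-api-key>",
    openai_api_type="azure",
    openai_api_version="2023-03-15-preview",
    openai_api_base="https://<resource-name>.openai.azure.com/",
)
# Ends up calling get_embedding(..., api_type="azure", ...) as shown above;
# _get_query_embedding is normally reached via BaseEmbedding's public wrapper.
vector = embedding_model._get_query_embedding("hello")
```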
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
         return embeddings
 
     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                           api_base=self.openai_api_base)
         return embeddings
diff --git a/api/core/index/index_builder.py b/api/core/index/index_builder.py
index baf16b0f3a..7f0486546e 100644
--- a/api/core/index/index_builder.py
+++ b/api/core/index/index_builder.py
@@ -33,8 +33,11 @@ class IndexBuilder:
             max_chunk_overlap=20
         )
 
+        provider = LLMBuilder.get_default_provider(tenant_id)
+
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=tenant_id,
+            model_provider=provider,
             model_name='text-embedding-ada-002'
         )
diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py
index 4355593c5d..30b0a931b3 100644
--- a/api/core/llm/llm_builder.py
+++ b/api/core/llm/llm_builder.py
@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM
 
 from core.constant import llm_constant
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
+from models.provider import ProviderType
 
 
 class LLMBuilder:
@@ -31,16 +36,23 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])
 
+        provider = cls.get_default_provider(tenant_id)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
-        model_credentials = cls.get_model_credentials(tenant_id, model_name)
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
         return llm_cls(
             model_name=model_name,
@@ -86,18 +98,31 @@ class LLMBuilder:
             raise ValueError(f"model name {model_name} is not supported.")
 
     @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
+    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
         """
         Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
 
         Raises an exception if the model_name is not found or if the provider is not found.
         """
         if not model_name:
             raise Exception('model name not found')
 
-        if model_name not in llm_constant.models:
-            raise Exception('model {} not found'.format(model_name))
-
-        model_provider = llm_constant.models[model_name]
 
         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
+
+    @classmethod
+    def get_default_provider(cls, tenant_id: str) -> str:
+        provider = BaseProvider.get_valid_provider(tenant_id)
+        if not provider:
+            raise ProviderTokenNotInitError()
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            provider_name = 'openai'
+        else:
+            provider_name = provider.provider_name
+
+        return provider_name
diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py
index e0ba0d0734..d68ed3ccc4 100644
--- a/api/core/llm/provider/azure_provider.py
+++ b/api/core/llm/provider/azure_provider.py
@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
         """
         Returns the API credentials for Azure OpenAI as a dictionary.
         """
-        encrypted_config = self.get_provider_api_key(model_id=model_id)
-        config = json.loads(encrypted_config)
+        config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id
+        config['deployment_name'] = model_id.replace('.', '')  # Azure deployment names disallow '.', e.g. 'gpt-3.5-turbo' -> 'gpt-35-turbo'
         return config
 
     def get_provider_name(self):
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
         """
         try:
             config = self.get_provider_api_key()
-            config = json.loads(config)
         except:
             config = {
                 'openai_api_type': 'azure',
                 'openai_api_version': '2023-03-15-preview',
-                'openai_api_base': 'https://foo.microsoft.com/bar',
+                'openai_api_base': 'https://.openai.azure.com/',
                 'openai_api_key': ''
             }
 
@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
             config = {
                 'openai_api_type': 'azure',
                 'openai_api_version': '2023-03-15-preview',
-                'openai_api_base': 'https://foo.microsoft.com/bar',
+                'openai_api_base': 'https://.openai.azure.com/',
                 'openai_api_key': ''
             }
diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py
index 89343ff62a..71bb32dca6 100644
--- a/api/core/llm/provider/base.py
+++ b/api/core/llm/provider/base.py
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id
 
-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
 
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
@@ -43,23 +43,35 @@
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
""" - providers = db.session.query(Provider).filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.get_provider_name().value - ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() + return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) + + @classmethod + def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: + """ + Returns the Provider instance for the given tenant_id and provider_name. + If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. + """ + query = db.session.query(Provider).filter( + Provider.tenant_id == tenant_id + ) + + if provider_name: + query = query.filter(Provider.provider_name == provider_name) + + providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() custom_provider = None system_provider = None for provider in providers: - if provider.provider_type == ProviderType.CUSTOM.value: + if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: custom_provider = provider - elif provider.provider_type == ProviderType.SYSTEM.value: + elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: system_provider = provider - if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config: + if custom_provider: return custom_provider - elif system_provider and system_provider.is_valid: + elif system_provider: return system_provider else: return None @@ -80,7 +92,7 @@ class BaseProvider(ABC): try: config = self.get_provider_api_key() except: - config = 'THIS-IS-A-MOCK-TOKEN' + config = '' if obfuscated: return self.obfuscated_token(config) diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py index 539ce92774..f3d514cf58 100644 --- a/api/core/llm/streamable_azure_chat_open_ai.py +++ b/api/core/llm/streamable_azure_chat_open_ai.py @@ -1,12 +1,50 @@ -import requests from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.chat_models import AzureChatOpenAI -from typing import Optional, List +from typing import Optional, List, Dict, Any + +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableAzureChatOpenAI(AzureChatOpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." 
+ ) + if values["n"] < 1: + raise ValueError("n must be at least 1.") + if values["n"] > 1 and values["streaming"]: + raise ValueError("n must be 1 when streaming.") + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + **super()._default_params, + "engine": self.deployment_name, + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + } + def get_messages_tokens(self, messages: List[BaseMessage]) -> int: """Get the number of tokens in a list of messages. diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py new file mode 100644 index 0000000000..e383f8cf23 --- /dev/null +++ b/api/core/llm/streamable_azure_open_ai.py @@ -0,0 +1,64 @@ +import os + +from langchain.llms import AzureOpenAI +from langchain.schema import LLMResult +from typing import Optional, List, Dict, Mapping, Any + +from pydantic import root_validator + +from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async + + +class StreamableAzureOpenAI(AzureOpenAI): + openai_api_type: str = "azure" + openai_api_version: str = "" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + + values["client"] = openai.Completion + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + if values["streaming"] and values["n"] > 1: + raise ValueError("Cannot stream results when n > 1.") + if values["streaming"] and values["best_of"] > 1: + raise ValueError("Cannot stream results when best_of > 1.") + return values + + @property + def _invocation_params(self) -> Dict[str, Any]: + return {**super()._invocation_params, **{ + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {**super()._identifying_params, **{ + "api_type": self.openai_api_type, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + }} + + @handle_llm_exceptions + def generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + return super().generate(prompts, stop) + + @handle_llm_exceptions_async + async def agenerate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + return await super().agenerate(prompts, stop) diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py index 59391e4ce0..582041ba09 100644 --- a/api/core/llm/streamable_chat_open_ai.py +++ b/api/core/llm/streamable_chat_open_ai.py @@ -1,12 +1,52 @@ +import os + from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.chat_models import ChatOpenAI -from typing import Optional, List +from typing import Optional, List, Dict, Any + +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class 
StreamableChatOpenAI(ChatOpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." + ) + if values["n"] < 1: + raise ValueError("n must be at least 1.") + if values["n"] > 1 and values["streaming"]: + raise ValueError("n must be 1 when streaming.") + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + **super()._default_params, + "api_type": 'openai', + "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), + "api_version": None, + "api_key": self.openai_api_key, + "organization": self.openai_organization if self.openai_organization else None, + } + def get_messages_tokens(self, messages: List[BaseMessage]) -> int: """Get the number of tokens in a list of messages. diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py index 94754af30e..9cf1b4c4bb 100644 --- a/api/core/llm/streamable_open_ai.py +++ b/api/core/llm/streamable_open_ai.py @@ -1,12 +1,54 @@ +import os + from langchain.schema import LLMResult -from typing import Optional, List +from typing import Optional, List, Dict, Any, Mapping from langchain import OpenAI +from pydantic import root_validator from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async class StreamableOpenAI(OpenAI): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + + values["client"] = openai.Completion + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." 
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**super()._invocation_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {**super()._identifying_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @handle_llm_exceptions
     def generate(
             self, prompts: List[str], stop: Optional[List[str]] = None
diff --git a/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx b/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
index 71236120e5..5681fd2204 100644
--- a/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
+++ b/web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
@@ -20,7 +20,7 @@ const AzureProvider = ({
   const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
   const handleFocus = () => {
     if (token === provider.token) {
-      token.azure_api_key = ''
+      token.openai_api_key = ''
       setToken({...token})
       onTokenChange({...token})
     }
@@ -35,31 +35,17 @@
          handleChange('azure_api_base', v)}
-        />
-          handleChange('azure_api_type', v)}
-        />
-          handleChange('azure_api_version', v)}
+        name={t('common.provider.azure.apiBase')}
+        placeholder={t('common.provider.azure.apiBasePlaceholder')}
+        value={token.openai_api_base}
+        onChange={(v) => handleChange('openai_api_base', v)}
       />
          handleChange('azure_api_key', v)}
+        value={token.openai_api_key}
+        onChange={v => handleChange('openai_api_key', v)}
         onFocus={handleFocus}
         onValidatedStatus={onValidatedStatus}
         providerName={provider.provider_name}
@@ -72,4 +58,4 @@
   )
 }
-export default AzureProvider
\ No newline at end of file
+export default AzureProvider
diff --git a/web/app/components/header/account-setting/provider-page/provider-item/index.tsx b/web/app/components/header/account-setting/provider-page/provider-item/index.tsx
index 4e8ef532e3..6a3cf85846 100644
--- a/web/app/components/header/account-setting/provider-page/provider-item/index.tsx
+++ b/web/app/components/header/account-setting/provider-page/provider-item/index.tsx
@@ -33,12 +33,12 @@ const ProviderItem = ({
   const { notify } = useContext(ToastContext)
   const [token, setToken] = useState(
     provider.provider_name === 'azure_openai'
-      ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' }
+      ? { openai_api_base: '', openai_api_key: '' }
       : ''
   )
   const id = `${provider.provider_name}-${provider.provider_type}`
   const isOpen = id === activeId
-  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token
+  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
   const comingSoon = false
   const isValid = provider.is_valid
 
@@ -135,4 +135,4 @@
   )
 }
-export default ProviderItem
\ No newline at end of file
+export default ProviderItem
diff --git a/web/i18n/lang/common.en.ts b/web/i18n/lang/common.en.ts
index f0106477c2..96304d2677 100644
--- a/web/i18n/lang/common.en.ts
+++ b/web/i18n/lang/common.en.ts
@@ -148,12 +148,8 @@ const translation = {
     editKey: 'Edit',
     invalidApiKey: 'Invalid API key',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
       apiKey: 'API Key',
       apiKeyPlaceholder: 'Enter your API key here',
       helpTip: 'Learn Azure OpenAI Service',
diff --git a/web/i18n/lang/common.zh.ts b/web/i18n/lang/common.zh.ts
index f96bdfa89d..496d27ad48 100644
--- a/web/i18n/lang/common.zh.ts
+++ b/web/i18n/lang/common.zh.ts
@@ -149,14 +149,10 @@ const translation = {
     editKey: '编辑',
     invalidApiKey: '无效的 API 密钥',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
       apiKey: 'API Key',
-      apiKeyPlaceholder: 'Enter your API key here',
+      apiKeyPlaceholder: '输入你的 API 密钥',
       helpTip: '了解 Azure OpenAI Service',
     },
     openaiHosted: {
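With the form trimmed to two inputs, the `token` object the console now submits for `azure_openai` (and that `ProviderTokenApi` validates and encrypts, per the first file in this diff) reduces to roughly the following sketch; both values are placeholders, and `AzureProvider.get_credentials` re-adds `openai_api_type` and derives `deployment_name` server-side:

```python
token = {
    "openai_api_base": "https://<resource-name>.openai.azure.com/",
    "openai_api_key": "<azure-api-key>",
}
```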
diff --git a/web/models/common.ts b/web/models/common.ts
index 21a74447e1..adce856fd1 100644
--- a/web/models/common.ts
+++ b/web/models/common.ts
@@ -55,10 +55,8 @@ export type Member = Pick
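Read end to end, the change routes each tenant to exactly one valid provider and picks the matching LLM class. A condensed sketch of that dispatch — not the literal method body from `llm_builder.py`, and the remaining constructor arguments are elided:

```python
# 'openai' for the shared system provider, otherwise the tenant's custom provider name.
provider = LLMBuilder.get_default_provider(tenant_id)

if LLMBuilder.get_mode_by_model(model_name) == 'chat':
    llm_cls = StreamableChatOpenAI if provider == 'openai' else StreamableAzureChatOpenAI
else:
    llm_cls = StreamableOpenAI if provider == 'openai' else StreamableAzureOpenAI

credentials = LLMBuilder.get_model_credentials(tenant_id, provider, model_name)
llm = llm_cls(model_name=model_name, **credentials)  # assumes the credentials dict unpacks as kwargs
```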