From 0f35d07052a319ec8768642e877dd067f63c8154 Mon Sep 17 00:00:00 2001 From: Bin Date: Fri, 14 Jun 2024 18:45:24 +0800 Subject: [PATCH] support ERNIE-4.0-8K-Latest (#5216) --- .../wenxin/llm/ernie-4.0-8k-latest.yaml | 40 +++++++++++++++++++ .../model_providers/wenxin/llm/ernie_bot.py | 35 ++++++++-------- 2 files changed, 58 insertions(+), 17 deletions(-) create mode 100644 api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml new file mode 100644 index 0000000000..50c82564f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml @@ -0,0 +1,40 @@ +model: ernie-4.0-8k-latest +label: + en_US: Ernie-4.0-8K-Latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 2048 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + use_template: response_format + - name: disable_search + label: + zh_Hans: 禁用搜索 + en_US: Disable Search + type: boolean + help: + zh_Hans: 禁用模型自行进行外部搜索。 + en_US: Prevent the model from performing external searches on its own. 
+ required: false diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 4646ba384a..305769c1c1 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -57,7 +57,7 @@ class BaiduAccessToken: raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') else: raise Exception(f'Unknown error: {resp["error_description"]}') - + return resp['access_token'] @staticmethod @@ -114,7 +114,7 @@ class ErnieMessage: 'role': self.role, 'content': self.content, } - + def __init__(self, content: str, role: str = 'user') -> None: self.content = content self.role = role @@ -131,6 +131,7 @@ class ErnieBotModel: 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', @@ -157,7 +158,7 @@ class ErnieBotModel: self.api_key = api_key self.secret_key = secret_key - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], + def generate(self, model: str, stream: bool, messages: list[ErnieMessage], parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ stop: list[str], user: str) \ -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: @@ -189,7 +190,7 @@ class ErnieBotModel: if stream: 
return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - + def _handle_error(self, code: int, msg: str): error_map = { 1: InternalServerError, @@ -234,15 +235,15 @@ class ErnieBotModel: def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) return token.access_token - + def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], + def _check_parameters(self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str]) -> None: if model not in self.api_bases: raise BadRequestError(f'Invalid model: {model}') - + # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') # ErnieBot supports function calling, however, there is lots of limitations. 
@@ -259,32 +260,32 @@ class ErnieBotModel: for s in stop: if len(s) > 20: raise BadRequestError('stop item should not exceed 20 characters.') - + def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - + def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], + parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str], user: str) \ -> dict[str, Any]: if len(messages) % 2 == 0: raise BadRequestError('The number of messages should be odd.') if messages[0].role == 'function': raise BadRequestError('The first message should be user message.') - + """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, + def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], stop: list[str], user: str) \ -> dict[str, Any]: if len(messages) == 0: raise BadRequestError('The number of messages should not be zero.') - + # check if the first element is system, shift it system_message = '' if messages[0].role == 'system': @@ -313,7 +314,7 @@ class ErnieBotModel: body['system'] = system_message return body - + def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() if 'error_code' in data: @@ -349,7 +350,7 @@ class ErnieBotModel: self._handle_error(code, msg) except Exception as e: raise InternalServerError(f'Failed to parse response: {e}') - + if line.startswith('data:'): line = line[5:].strip() else: @@ -361,7 +362,7 @@ 
class ErnieBotModel: data = loads(line) except Exception as e: raise InternalServerError(f'Failed to parse response: {e}') - + result = data['result'] is_end = data['is_end'] @@ -379,4 +380,4 @@ class ErnieBotModel: yield message else: message = ErnieMessage(content=result, role='assistant') - yield message \ No newline at end of file + yield message