support ERNIE-4.0-8K-Latest (#5216)

This commit is contained in:
Bin 2024-06-14 18:45:24 +08:00 committed by GitHub
parent 7f44e88eda
commit 0f35d07052
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 17 deletions

View File

@ -0,0 +1,40 @@
# Model id must match the lowercase key registered in ErnieBotModel.api_bases
# ('ernie-4.0-8k-latest'); the mixed-case 'ernie-4.0-8k-Latest' fails the
# exact-match membership check in _check_parameters and raises BadRequestError.
model: ernie-4.0-8k-latest
label:
  en_US: Ernie-4.0-8K-Latest
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.8
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 2048
  - name: presence_penalty
    use_template: presence_penalty
    # Baidu's ERNIE API uses the 1.0–2.0 range (not the OpenAI-style -2..2).
    default: 1.0
    min: 1.0
    max: 2.0
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
      en_US: Disable Search
    type: boolean
    help:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false

View File

@ -57,7 +57,7 @@ class BaiduAccessToken:
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
@ -114,7 +114,7 @@ class ErnieMessage:
'role': self.role,
'content': self.content,
}
def __init__(self, content: str, role: str = 'user') -> None:
    """Build a message carrying *content*, attributed to *role* ('user' by default)."""
    self.role = role
    self.content = content
@ -131,6 +131,7 @@ class ErnieBotModel:
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
@ -157,7 +158,7 @@ class ErnieBotModel:
self.api_key = api_key
self.secret_key = secret_key
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
stop: list[str], user: str) \
-> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
@ -189,7 +190,7 @@ class ErnieBotModel:
if stream:
return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp)
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
@ -234,15 +235,15 @@ class ErnieBotModel:
def _get_access_token(self) -> str:
    """Return the raw access-token string for this model's API/secret key pair.

    Delegates entirely to BaiduAccessToken.get_access_token; only the
    token's string value is exposed to callers.
    """
    return BaiduAccessToken.get_access_token(self.api_key, self.secret_key).access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
    """Return fresh ErnieMessage instances mirroring each input's content and role."""
    duplicated: list[ErnieMessage] = []
    for original in messages:
        duplicated.append(ErnieMessage(original.content, original.role))
    return duplicated
def _check_parameters(self, model: str, parameters: dict[str, Any],
def _check_parameters(self, model: str, parameters: dict[str, Any],
tools: list[PromptMessageTool], stop: list[str]) -> None:
if model not in self.api_bases:
raise BadRequestError(f'Invalid model: {model}')
# if model not in self.function_calling_supports and tools is not None and len(tools) > 0:
# raise BadRequestError(f'Model {model} does not support calling function.')
# ErnieBot supports function calling, however, there is lots of limitations.
@ -259,32 +260,32 @@ class ErnieBotModel:
for s in stop:
if len(s) > 20:
raise BadRequestError('stop item should not exceed 20 characters.')
def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any],
                        tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]:
    """Assemble the request body for a chat completion call.

    Function-calling dispatch is deliberately disabled (see the commented
    branch below — ErnieBot's function calling has too many limitations),
    so every model is routed through the plain chat body builder.
    """
    # if model in self.function_calling_supports:
    #     return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
    return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
                                         parameters: dict[str, Any], tools: list[PromptMessageTool],
                                         stop: list[str], user: str) \
        -> dict[str, Any]:
    """Validate messages for a function-calling request.

    NOTE(review): body construction is not implemented yet — after the
    validations below this falls through and implicitly returns None,
    despite the declared return type. No caller currently reaches it
    (the dispatch in _build_request_body is commented out).

    Raises BadRequestError when the message count is even, or when the
    first message has the 'function' role.
    """
    if len(messages) % 2 == 0:
        raise BadRequestError('The number of messages should be odd.')
    if messages[0].role == 'function':
        raise BadRequestError('The first message should be user message.')
    # TODO: implement function calling request body construction
def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
parameters: dict[str, Any], stop: list[str], user: str) \
-> dict[str, Any]:
if len(messages) == 0:
raise BadRequestError('The number of messages should not be zero.')
# check if the first element is system, shift it
system_message = ''
if messages[0].role == 'system':
@ -313,7 +314,7 @@ class ErnieBotModel:
body['system'] = system_message
return body
def _handle_chat_generate_response(self, response: Response) -> ErnieMessage:
data = response.json()
if 'error_code' in data:
@ -349,7 +350,7 @@ class ErnieBotModel:
self._handle_error(code, msg)
except Exception as e:
raise InternalServerError(f'Failed to parse response: {e}')
if line.startswith('data:'):
line = line[5:].strip()
else:
@ -361,7 +362,7 @@ class ErnieBotModel:
data = loads(line)
except Exception as e:
raise InternalServerError(f'Failed to parse response: {e}')
result = data['result']
is_end = data['is_end']
@ -379,4 +380,4 @@ class ErnieBotModel:
yield message
else:
message = ErnieMessage(content=result, role='assistant')
yield message
yield message