chore: optimize ark model parameters (#7378)

sino authored 2024-08-19 08:44:19 +08:00, committed by GitHub
parent 6cd8ab0cbc
commit a0a67873aa
4 changed files with 171 additions and 246 deletions

api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.llm.models import (
+    get_model_config,
+    get_v2_req_params,
+)
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException

 logger = logging.getLogger(__name__)

@@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             -> LLMResult | Generator:
         client = MaaSClient.from_credential(credentials)

-        req_params = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).copy()
-        if credentials.get('context_size'):
-            req_params['max_prompt_tokens'] = credentials.get('context_size')
-        if credentials.get('max_tokens'):
-            req_params['max_new_tokens'] = credentials.get('max_tokens')
-        if model_parameters.get('max_tokens'):
-            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
-        if model_parameters.get('temperature'):
-            req_params['temperature'] = model_parameters.get('temperature')
-        if model_parameters.get('top_p'):
-            req_params['top_p'] = model_parameters.get('top_p')
-        if model_parameters.get('top_k'):
-            req_params['top_k'] = model_parameters.get('top_k')
-        if model_parameters.get('presence_penalty'):
-            req_params['presence_penalty'] = model_parameters.get(
-                'presence_penalty')
-        if model_parameters.get('frequency_penalty'):
-            req_params['frequency_penalty'] = model_parameters.get(
-                'frequency_penalty')
-        if stop:
-            req_params['stop'] = stop
+        req_params = get_v2_req_params(credentials, model_parameters, stop)
         extra_model_kwargs = {}
         if tools:
             extra_model_kwargs['tools'] = [
                 MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
             ]
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:

@@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         """
         used to define customizable model schema
         """
-        max_tokens = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
-        if credentials.get('max_tokens'):
-            max_tokens = int(credentials.get('max_tokens'))
+        model_config = get_model_config(credentials)
+
         rules = [
             ParameterRule(
                 name='temperature',

@@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='presence_penalty',
                 type=ParameterType.FLOAT,
                 use_template='presence_penalty',
-                label={
-                    'en_US': 'Presence Penalty',
-                    'zh_Hans': '存在惩罚',
-                },
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),

@@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='frequency_penalty',
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
-                label={
-                    'en_US': 'Frequency Penalty',
-                    'zh_Hans': '频率惩罚',
-                },
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),

@@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.INT,
                 use_template='max_tokens',
                 min=1,
-                max=max_tokens,
+                max=model_config.properties.max_tokens,
                 default=512,
                 label=I18nObject(
                     zh_Hans='最大生成长度',

@@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             ),
         ]

-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('mode'):
-            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-        model_features = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('features', [])
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(

@@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model_type=ModelType.LLM,
             model_properties=model_properties,
             parameter_rules=rules,
-            features=model_features,
+            features=model_config.features,
         )

         return entity

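The hand-rolled if-chain above is now a single call to get_v2_req_params, defined in llm/models.py below. The following is a minimal standalone sketch of the merge order that call implements (predefined model limits first, then caller-supplied parameters, then stop words); build_req_params is a hypothetical name for illustration, not part of this commit:

# Hypothetical sketch of the merge performed by get_v2_req_params:
# model-config defaults come first, caller parameters override them.
def build_req_params(defaults: dict, model_parameters: dict,
                     stop: list[str] | None = None) -> dict:
    req_params = dict(defaults)  # e.g. {'max_prompt_tokens': 4096, 'max_new_tokens': 4096}
    if model_parameters.get('max_tokens'):
        req_params['max_new_tokens'] = model_parameters['max_tokens']
    for key in ('temperature', 'top_p', 'top_k',
                'presence_penalty', 'frequency_penalty'):
        if model_parameters.get(key):  # falsy values (0, 0.0) are skipped, as in the original
            req_params[key] = model_parameters[key]
    if stop:
        req_params['stop'] = stop
    return req_params

# A 4k model with a caller-supplied temperature and output cap:
print(build_req_params(
    {'max_prompt_tokens': 4096, 'max_new_tokens': 4096},
    {'temperature': 0.7, 'max_tokens': 1024},
    stop=['Observation:'],
))
# {'max_prompt_tokens': 4096, 'max_new_tokens': 1024, 'temperature': 0.7, 'stop': ['Observation:']}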
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -1,181 +1,123 @@
+from pydantic import BaseModel
+
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature

-ModelConfigs = {
-    'Doubao-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Skylark2-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4000,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-8B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-70B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-8k': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 16384,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 65536,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B-Fin': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Mistral-7B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 2048,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    }
-}
+
+
+class ModelProperties(BaseModel):
+    context_size: int
+    max_tokens: int
+    mode: LLMMode
+
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+    features: list[ModelFeature]
+
+
+configs: dict[str, ModelConfig] = {
+    'Doubao-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Skylark2-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-8B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-70B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-8k': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B-Fin': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Mistral-7B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
+        features=[]
+    )
+}
+
+
+def get_model_config(credentials: dict) -> ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = configs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+            properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_tokens=int(credentials.get('max_tokens', 0)),
+                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
+            ),
+            features=[],
+        )
+    return model_configs
+
+
+def get_v2_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_prompt_tokens'] = model_configs.properties.context_size
+        req_params['max_new_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('top_k'):
+        req_params['top_k'] = model_parameters.get('top_k')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params

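The rewrite above trades the nested-dict ModelConfigs lookup for typed pydantic models: call sites get validated attribute access (config.properties.max_tokens) instead of chained .get() calls, and unknown base models fall back to credential-supplied values in one place. A self-contained sketch of the same pattern with simplified stand-ins (Mode, Properties, Config, CONFIGS and get_config are illustrative names, not the commit's):

from enum import Enum

from pydantic import BaseModel


class Mode(str, Enum):  # stand-in for LLMMode
    CHAT = 'chat'


class Properties(BaseModel):  # stand-in for ModelProperties
    context_size: int
    max_tokens: int
    mode: Mode


class Config(BaseModel):  # stand-in for ModelConfig
    properties: Properties
    features: list[str]


CONFIGS = {
    'Doubao-pro-4k': Config(
        properties=Properties(context_size=4096, max_tokens=4096, mode=Mode.CHAT),
        features=['tool-call'],
    ),
}


def get_config(credentials: dict) -> Config:
    known = CONFIGS.get(credentials.get('base_model_name', ''))
    if known:
        return known
    # Unknown base models fall back to whatever the user entered in the
    # credentials form, mirroring get_model_config above.
    return Config(
        properties=Properties(
            context_size=int(credentials.get('context_size', 0)),
            max_tokens=int(credentials.get('max_tokens', 0)),
            mode=Mode(credentials.get('mode', 'chat')),
        ),
        features=[],
    )


# Attribute access replaces chained dict lookups:
assert get_config({'base_model_name': 'Doubao-pro-4k'}).properties.max_tokens == 4096
assert get_config({'context_size': '16384', 'max_tokens': '4096'}).properties.context_size == 16384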
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py

@@ -1,9 +1,27 @@
+from pydantic import BaseModel
+
+
+class ModelProperties(BaseModel):
+    context_size: int
+    max_chunks: int
+
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+
+
 ModelConfigs = {
-    'Doubao-embedding': {
-        'req_params': {},
-        'model_properties': {
-            'context_size': 4096,
-            'max_chunks': 1,
-        }
-    },
+    'Doubao-embedding': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_chunks=1)
+    ),
 }
+
+
+def get_model_config(credentials: dict) -> ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = ModelConfigs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+            properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_chunks=int(credentials.get('max_chunks', 0)),
+            )
+        )
+    return model_configs

api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException

@@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         """
         generate custom model entities from credentials
         """
-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-        if credentials.get('max_chunks'):
-            model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
-                credentials.get('max_chunks', 4096))
+        model_config = get_model_config(credentials)
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
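
As on the LLM side, the embedding entity now reads both properties from one typed config instead of patching a copied dict per credential key; this also drops the old dead default in credentials.get('max_chunks', 4096), which could never apply because the surrounding if already guaranteed the key was set. A minimal sketch of the resulting mapping, with hypothetical stand-ins for ModelPropertyKey and the typed properties:

from enum import Enum

from pydantic import BaseModel


class PropertyKey(str, Enum):  # stand-in for ModelPropertyKey
    CONTEXT_SIZE = 'context_size'
    MAX_CHUNKS = 'max_chunks'


class EmbeddingProperties(BaseModel):  # stand-in for the typed properties
    context_size: int
    max_chunks: int


def build_model_properties(props: EmbeddingProperties) -> dict:
    # One typed source of truth feeds the entity's property map.
    return {
        PropertyKey.CONTEXT_SIZE: props.context_size,
        PropertyKey.MAX_CHUNKS: props.max_chunks,
    }


props = build_model_properties(EmbeddingProperties(context_size=4096, max_chunks=1))
assert props[PropertyKey.CONTEXT_SIZE] == 4096 and props[PropertyKey.MAX_CHUNKS] == 1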