diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
index 8bea30324b..add5822bef 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.llm.models import (
+    get_model_config,
+    get_v2_req_params,
+)
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 logger = logging.getLogger(__name__)
@@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             -> LLMResult | Generator:
         client = MaaSClient.from_credential(credentials)
-
-        req_params = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).copy()
-        if credentials.get('context_size'):
-            req_params['max_prompt_tokens'] = credentials.get('context_size')
-        if credentials.get('max_tokens'):
-            req_params['max_new_tokens'] = credentials.get('max_tokens')
-        if model_parameters.get('max_tokens'):
-            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
-        if model_parameters.get('temperature'):
-            req_params['temperature'] = model_parameters.get('temperature')
-        if model_parameters.get('top_p'):
-            req_params['top_p'] = model_parameters.get('top_p')
-        if model_parameters.get('top_k'):
-            req_params['top_k'] = model_parameters.get('top_k')
-        if model_parameters.get('presence_penalty'):
-            req_params['presence_penalty'] = model_parameters.get(
-                'presence_penalty')
-        if model_parameters.get('frequency_penalty'):
-            req_params['frequency_penalty'] = model_parameters.get(
-                'frequency_penalty')
-        if stop:
-            req_params['stop'] = stop
-
+        req_params = get_v2_req_params(credentials, model_parameters, stop)
         extra_model_kwargs = {}
-
         if tools:
             extra_model_kwargs['tools'] = [
                 MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
             ]
-
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
@@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         """
             used to define customizable model schema
         """
-        max_tokens = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
-        if credentials.get('max_tokens'):
-            max_tokens = int(credentials.get('max_tokens'))
+        model_config = get_model_config(credentials)
+
         rules = [
             ParameterRule(
                 name='temperature',
@@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='presence_penalty',
                 type=ParameterType.FLOAT,
                 use_template='presence_penalty',
-                label={
-                    'en_US': 'Presence Penalty',
-                    'zh_Hans': '存在惩罚',
-                },
+                label=I18nObject(
+                    en_US='Presence Penalty',
+                    zh_Hans='存在惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),
@@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 name='frequency_penalty',
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
-                label={
-                    'en_US': 'Frequency Penalty',
-                    'zh_Hans': '频率惩罚',
-                },
+                label=I18nObject(
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
+                ),
                 min=-2.0,
                 max=2.0,
             ),
@@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.INT,
                 use_template='max_tokens',
                 min=1,
-                max=max_tokens,
+                max=model_config.properties.max_tokens,
                 default=512,
                 label=I18nObject(
                     zh_Hans='最大生成长度',
@@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             ),
         ]
 
-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('mode'):
-            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-
-        model_features = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('features', [])
-
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model_type=ModelType.LLM,
             model_properties=model_properties,
             parameter_rules=rules,
-            features=model_features,
+            features=model_config.features,
         )
         return entity
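Reviewer note: the parameter-merging if-chain removed above now lives in `get_v2_req_params` (models.py, next file). A minimal standalone sketch of the behaviour being centralized — `build_req_params` and `defaults` are illustrative names for this sketch, not the provider's API:

```python
# Sketch of the merging rule that moved out of llm.py. `defaults` stands in
# for the per-model values the real helper reads from get_model_config().
def build_req_params(defaults: dict, model_parameters: dict,
                     stop: list[str] | None = None) -> dict:
    req_params = dict(defaults)
    # The UI name 'max_tokens' maps to the upstream field 'max_new_tokens'.
    if model_parameters.get('max_tokens'):
        req_params['max_new_tokens'] = model_parameters['max_tokens']
    # Remaining knobs keep their names; falsy values are skipped, matching
    # the original if-chain (so an explicit temperature=0 is also skipped).
    for key in ('temperature', 'top_p', 'top_k',
                'presence_penalty', 'frequency_penalty'):
        if model_parameters.get(key):
            req_params[key] = model_parameters[key]
    if stop:
        req_params['stop'] = stop
    return req_params


print(build_req_params({'max_prompt_tokens': 4096, 'max_new_tokens': 4096},
                       {'temperature': 0.7}, stop=['Human:']))
# {'max_prompt_tokens': 4096, 'max_new_tokens': 4096,
#  'temperature': 0.7, 'stop': ['Human:']}
```

Note that the truthiness checks are carried over unchanged from the old code, so `temperature=0` is still silently ignored; that quirk predates this refactor.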
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
index 3e5938f3b4..c5e53d8955 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -1,181 +1,123 @@
+from pydantic import BaseModel
+
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature
 
-ModelConfigs = {
-    'Doubao-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 32768,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-pro-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Doubao-lite-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 131072,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [
-            ModelFeature.TOOL_CALL
-        ],
-    },
-    'Skylark2-pro-4k': {
-        'req_params': {
-            'max_prompt_tokens': 4096,
-            'max_new_tokens': 4000,
-        },
-        'model_properties': {
-            'context_size': 4096,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-8B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Llama3-70B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 8192,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-8k': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-32k': {
-        'req_params': {
-            'max_prompt_tokens': 32768,
-            'max_new_tokens': 16384,
-        },
-        'model_properties': {
-            'context_size': 32768,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Moonshot-v1-128k': {
-        'req_params': {
-            'max_prompt_tokens': 131072,
-            'max_new_tokens': 65536,
-        },
-        'model_properties': {
-            'context_size': 131072,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'GLM3-130B-Fin': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 4096,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    },
-    'Mistral-7B': {
-        'req_params': {
-            'max_prompt_tokens': 8192,
-            'max_new_tokens': 2048,
-        },
-        'model_properties': {
-            'context_size': 8192,
-            'mode': 'chat',
-        },
-        'features': [],
-    }
+
+class ModelProperties(BaseModel):
+    context_size: int
+    max_tokens: int
+    mode: LLMMode
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+    features: list[ModelFeature]
+
+
+configs: dict[str, ModelConfig] = {
+    'Doubao-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-pro-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Doubao-lite-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        features=[ModelFeature.TOOL_CALL]
+    ),
+    'Skylark2-pro-4k': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-8B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Llama3-70B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-8k': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-32k': ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Moonshot-v1-128k': ModelConfig(
+        properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'GLM3-130B-Fin': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[]
+    ),
+    'Mistral-7B': ModelConfig(
+        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
+        features=[]
+    )
 }
+
+def get_model_config(credentials: dict) -> ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = configs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+            properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_tokens=int(credentials.get('max_tokens', 0)),
+                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
+            ),
+            features=[]
+        )
+    return model_configs
+
+
+def get_v2_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_prompt_tokens'] = model_configs.properties.context_size
+        req_params['max_new_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('top_k'):
+        req_params['top_k'] = model_parameters.get('top_k')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params
\ No newline at end of file
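The typed config replaces chained `.get(...).get(...)` lookups with validated attribute access. A self-contained usage sketch (the `LLMMode` enum below is a stand-in for `core.model_runtime.entities.llm_entities.LLMMode`, and the credential values are invented):

```python
from enum import Enum

from pydantic import BaseModel


class LLMMode(str, Enum):
    # Stand-in for core.model_runtime.entities.llm_entities.LLMMode
    CHAT = 'chat'
    COMPLETION = 'completion'


class ModelProperties(BaseModel):
    context_size: int
    max_tokens: int
    mode: LLMMode


class ModelConfig(BaseModel):
    properties: ModelProperties
    features: list[str]


# Known base model: values come straight from the `configs` table.
doubao = ModelConfig(
    properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
    features=['tool-call'],
)
assert doubao.properties.max_tokens == 4096

# Unknown base model: the fallback path coerces credential strings, so a
# malformed value now raises at config time instead of leaking downstream.
fallback = ModelConfig(
    properties=ModelProperties(
        context_size=int('32768'),  # credentials['context_size']
        max_tokens=int('4096'),     # credentials['max_tokens']
        mode=LLMMode('chat'),       # credentials.get('mode', 'chat')
    ),
    features=[],
)
print(fallback.properties.context_size)  # 32768
```

One behavioural caveat worth flagging: when a custom model's credentials omit `context_size`/`max_tokens`, the fallback now yields 0 for both, and `get_v2_req_params` will send those zeros as `max_prompt_tokens`/`max_new_tokens`, where the old code simply left the keys unset.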
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
index 569f89e975..2d8f972b94 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
@@ -1,9 +1,27 @@
+from pydantic import BaseModel
+
+
+class ModelProperties(BaseModel):
+    context_size: int
+    max_chunks: int
+
+class ModelConfig(BaseModel):
+    properties: ModelProperties
+
 ModelConfigs = {
-    'Doubao-embedding': {
-        'req_params': {},
-        'model_properties': {
-            'context_size': 4096,
-            'max_chunks': 1,
-        }
-    },
+    'Doubao-embedding': ModelConfig(
+        properties=ModelProperties(context_size=4096, max_chunks=1)
+    ),
 }
+
+def get_model_config(credentials: dict) -> ModelConfig:
+    base_model = credentials.get('base_model_name', '')
+    model_configs = ModelConfigs.get(base_model)
+    if not model_configs:
+        return ModelConfig(
+            properties=ModelProperties(
+                context_size=int(credentials.get('context_size', 0)),
+                max_chunks=int(credentials.get('max_chunks', 0)),
+            )
+        )
+    return model_configs
\ No newline at end of file
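The embedding config mirrors the LLM one, minus `mode` and `features`. A runnable sketch of the lookup-with-fallback pattern both modules now share (models redeclared locally so the snippet is self-contained; `MODEL_CONFIGS` is an illustrative copy of the module-level table):

```python
from pydantic import BaseModel


class ModelProperties(BaseModel):
    context_size: int
    max_chunks: int


class ModelConfig(BaseModel):
    properties: ModelProperties


MODEL_CONFIGS = {
    'Doubao-embedding': ModelConfig(
        properties=ModelProperties(context_size=4096, max_chunks=1)),
}


def get_model_config(credentials: dict) -> ModelConfig:
    # Known base models win; anything else is built from the credentials,
    # defaulting to 0 when a value is absent.
    config = MODEL_CONFIGS.get(credentials.get('base_model_name', ''))
    if config is not None:
        return config
    return ModelConfig(properties=ModelProperties(
        context_size=int(credentials.get('context_size', 0)),
        max_chunks=int(credentials.get('max_chunks', 0)),
    ))


print(get_model_config({'base_model_name': 'Doubao-embedding'}).properties.max_chunks)  # 1
print(get_model_config({'context_size': '8192', 'max_chunks': '4'}).properties.max_chunks)  # 4
```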
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
index 10b01c0d0d..8ac632369e 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
@@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 
@@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         """
            generate custom model entities from credentials
         """
-        model_properties = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('model_properties', {}).copy()
-        if credentials.get('context_size'):
-            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
-                credentials.get('context_size', 4096))
-        if credentials.get('max_chunks'):
-            model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
-                credentials.get('max_chunks', 4096))
+        model_config = get_model_config(credentials)
+        model_properties = {}
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
+        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
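Finally, a small optional tidy-up for both `_get_customizable_model_schema` call sites (a reviewer suggestion, not part of the diff): the pair of keyed assignments can collapse into one dict literal. Runnable with stand-ins for the Dify types:

```python
from enum import Enum
from types import SimpleNamespace


class ModelPropertyKey(str, Enum):
    # Stand-in for core.model_runtime.entities.model_entities.ModelPropertyKey
    CONTEXT_SIZE = 'context_size'
    MAX_CHUNKS = 'max_chunks'


# Stand-in for the object returned by get_model_config(credentials);
# the values are illustrative.
model_config = SimpleNamespace(
    properties=SimpleNamespace(context_size=4096, max_chunks=1))

# Equivalent to the two model_properties[...] assignments in the diff.
model_properties = {
    ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
    ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks,
}
print(model_properties)
```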