mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
Feat/add triton inference server (#2928)
This commit is contained in:
parent
16af509c46
commit
240a94182e
|
@ -11,6 +11,8 @@
|
|||
- groq
|
||||
- replicate
|
||||
- huggingface_hub
|
||||
- xinference
|
||||
- triton_inference_server
|
||||
- zhipuai
|
||||
- baichuan
|
||||
- spark
|
||||
|
@ -20,7 +22,6 @@
|
|||
- moonshot
|
||||
- jina
|
||||
- chatglm
|
||||
- xinference
|
||||
- yi
|
||||
- openllm
|
||||
- localai
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 78 KiB |
|
@ -0,0 +1,3 @@
|
|||
<svg width="567" height="376" viewBox="0 0 567 376" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M58.0366 161.868C58.0366 161.868 109.261 86.2912 211.538 78.4724V51.053C98.2528 60.1511 0.152344 156.098 0.152344 156.098C0.152344 156.098 55.7148 316.717 211.538 331.426V302.282C97.1876 287.896 58.0366 161.868 58.0366 161.868ZM211.538 244.32V271.013C125.114 255.603 101.125 165.768 101.125 165.768C101.125 165.768 142.621 119.799 211.538 112.345V141.633C211.486 141.633 211.449 141.617 211.406 141.617C175.235 137.276 146.978 171.067 146.978 171.067C146.978 171.067 162.816 227.949 211.538 244.32ZM211.538 0.47998V51.053C214.864 50.7981 218.189 50.5818 221.533 50.468C350.326 46.1273 434.243 156.098 434.243 156.098C434.243 156.098 337.861 273.296 237.448 273.296C228.245 273.296 219.63 272.443 211.538 271.009V302.282C218.695 303.201 225.903 303.667 233.119 303.675C326.56 303.675 394.134 255.954 459.566 199.474C470.415 208.162 514.828 229.299 523.958 238.55C461.745 290.639 316.752 332.626 234.551 332.626C226.627 332.626 219.018 332.148 211.538 331.426V375.369H566.701V0.47998H211.538ZM211.538 112.345V78.4724C214.829 78.2425 218.146 78.0672 221.533 77.9602C314.148 75.0512 374.909 157.548 374.909 157.548C374.909 157.548 309.281 248.693 238.914 248.693C228.787 248.693 219.707 247.065 211.536 244.318V141.631C247.591 145.987 254.848 161.914 276.524 198.049L324.737 157.398C324.737 157.398 289.544 111.243 230.219 111.243C223.768 111.241 217.597 111.696 211.538 112.345Z" fill="#77B900"/>
|
||||
</svg>
|
After Width: | Height: | Size: 1.5 KiB |
|
@ -0,0 +1,267 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from httpx import Response, post
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
"""
|
||||
invoke LLM
|
||||
|
||||
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
|
||||
"""
|
||||
return self._generate(
|
||||
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
|
||||
tools=tools, stop=stop, stream=stream, user=user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate credentials
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('server_url is required in credentials')
|
||||
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, prompt_messages=[
|
||||
UserPromptMessage(content='ping')
|
||||
], model_parameters={}, stream=False)
|
||||
except InvokeError as ex:
|
||||
raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
"""
|
||||
get number of tokens
|
||||
|
||||
cause TritonInference LLM is a customized model, we could net detect which tokenizer to use
|
||||
so we just take the GPT2 tokenizer as default
|
||||
"""
|
||||
return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
|
||||
|
||||
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
|
||||
"""
|
||||
convert prompt message to text
|
||||
"""
|
||||
text = ''
|
||||
for item in message:
|
||||
if isinstance(item, UserPromptMessage):
|
||||
text += f'User: {item.content}'
|
||||
elif isinstance(item, SystemPromptMessage):
|
||||
text += f'System: {item.content}'
|
||||
elif isinstance(item, AssistantPromptMessage):
|
||||
text += f'Assistant: {item.content}'
|
||||
else:
|
||||
raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
|
||||
return text
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度',
|
||||
en_US='Temperature'
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P',
|
||||
en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=int(credentials.get('context_length', 2048)),
|
||||
default=min(512, int(credentials.get('context_length', 2048))),
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
completion_type = None
|
||||
|
||||
if 'completion_type' in credentials:
|
||||
if credentials['completion_type'] == 'chat':
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
parameter_rules=rules,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)),
|
||||
},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
"""
|
||||
generate text from LLM
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('server_url is required in credentials')
|
||||
|
||||
if 'stream' in credentials and not bool(credentials['stream']) and stream:
|
||||
raise ValueError(f'stream is not supported by model {model}')
|
||||
|
||||
try:
|
||||
parameters = {}
|
||||
if 'temperature' in model_parameters:
|
||||
parameters['temperature'] = model_parameters['temperature']
|
||||
if 'top_p' in model_parameters:
|
||||
parameters['top_p'] = model_parameters['top_p']
|
||||
if 'top_k' in model_parameters:
|
||||
parameters['top_k'] = model_parameters['top_k']
|
||||
if 'presence_penalty' in model_parameters:
|
||||
parameters['presence_penalty'] = model_parameters['presence_penalty']
|
||||
if 'frequency_penalty' in model_parameters:
|
||||
parameters['frequency_penalty'] = model_parameters['frequency_penalty']
|
||||
|
||||
response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
|
||||
'text_input': self._convert_prompt_message_to_text(prompt_messages),
|
||||
'max_tokens': model_parameters.get('max_tokens', 512),
|
||||
'parameters': {
|
||||
'stream': False,
|
||||
**parameters
|
||||
},
|
||||
}, timeout=(10, 120))
|
||||
response.raise_for_status()
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
except Exception as ex:
|
||||
raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: Response) -> LLMResult:
|
||||
"""
|
||||
handle normal chat generate response
|
||||
"""
|
||||
text = resp.json()['text_output']
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=text
|
||||
),
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
resp: Response) -> Generator:
|
||||
"""
|
||||
handle normal chat generate response
|
||||
"""
|
||||
text = resp.json()['text_output']
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=text
|
||||
),
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
ValueError
|
||||
]
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class XinferenceAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
|
@ -0,0 +1,84 @@
|
|||
provider: triton_inference_server
|
||||
label:
|
||||
en_US: Triton Inference Server
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#EFFDFD"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy Triton Inference Server
|
||||
zh_Hans: 如何部署 Triton Inference Server
|
||||
url:
|
||||
en_US: https://github.com/triton-inference-server/server
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
|
||||
en_US: Enter the url of your Triton Inference Server, e.g. http://192.168.1.100:8000
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 上下文大小
|
||||
en_US: Context size
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的上下文大小
|
||||
en_US: Enter the context size
|
||||
default: 2048
|
||||
- variable: completion_type
|
||||
label:
|
||||
zh_Hans: 补全类型
|
||||
en_US: Model type
|
||||
type: select
|
||||
required: true
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的补全类型
|
||||
en_US: Enter the completion type
|
||||
options:
|
||||
- label:
|
||||
zh_Hans: 补全模型
|
||||
en_US: Completion model
|
||||
value: completion
|
||||
- label:
|
||||
zh_Hans: 对话模型
|
||||
en_US: Chat model
|
||||
value: chat
|
||||
- variable: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream output
|
||||
type: select
|
||||
required: true
|
||||
default: true
|
||||
placeholder:
|
||||
zh_Hans: 是否支持流式输出
|
||||
en_US: Whether to support stream output
|
||||
options:
|
||||
- label:
|
||||
zh_Hans: 是
|
||||
en_US: Yes
|
||||
value: true
|
||||
- label:
|
||||
zh_Hans: 否
|
||||
en_US: No
|
||||
value: false
|
Loading…
Reference in New Issue
Block a user